1use std::fmt;
2use std::iter::FromIterator;
3use std::ops::*;
4
5use arrayfire as af;
6use async_trait::async_trait;
7use destream::{de, en};
8use futures::TryFutureExt;
9use get_size::GetSize;
10use num_traits::{FromPrimitive, ToPrimitive};
11use number_general::*;
12use safecast::{as_type, CastFrom, CastInto};
13use serde::de::{Deserialize, Deserializer};
14use serde::ser::{Serialize, Serializer};
15
16use super::ext::*;
17use super::{error, Complex, Result};
18
19pub fn product_dtype(array_dtype: NumberType) -> NumberType {
21 use {ComplexType as CT, FloatType as FT, IntType as IT, NumberType as NT, UIntType as UT};
22
23 match array_dtype {
24 NT::Bool => ArrayExt::<bool>::product_dtype(),
25 NT::Complex(ct) => match ct {
26 CT::C32 => ArrayExt::<Complex<f32>>::product_dtype(),
27 CT::C64 => ArrayExt::<Complex<f64>>::product_dtype(),
28 CT::Complex => ArrayExt::<Complex<f64>>::product_dtype(),
29 },
30 NT::Float(ft) => match ft {
31 FT::F32 => ArrayExt::<f32>::product_dtype(),
32 FT::F64 => ArrayExt::<f64>::product_dtype(),
33 FT::Float => ArrayExt::<f64>::product_dtype(),
34 },
35 NT::Int(it) => match it {
36 IT::I8 => ArrayExt::<i16>::product_dtype(),
37 IT::I16 => ArrayExt::<i16>::product_dtype(),
38 IT::I32 => ArrayExt::<i32>::product_dtype(),
39 IT::I64 => ArrayExt::<i64>::product_dtype(),
40 IT::Int => ArrayExt::<i64>::product_dtype(),
41 },
42 NT::UInt(ut) => match ut {
43 UT::U8 => ArrayExt::<u8>::product_dtype(),
44 UT::U16 => ArrayExt::<u16>::product_dtype(),
45 UT::U32 => ArrayExt::<u32>::product_dtype(),
46 UT::U64 => ArrayExt::<u64>::product_dtype(),
47 UT::UInt => ArrayExt::<u64>::product_dtype(),
48 },
49 NT::Number => ArrayExt::<f64>::product_dtype(),
50 }
51}
52
53pub fn sum_dtype(array_dtype: NumberType) -> NumberType {
55 use {ComplexType as CT, FloatType as FT, IntType as IT, NumberType as NT, UIntType as UT};
56
57 match array_dtype {
58 NT::Bool => ArrayExt::<bool>::sum_dtype(),
59 NT::Complex(ct) => match ct {
60 CT::C32 => ArrayExt::<Complex<f32>>::sum_dtype(),
61 CT::C64 => ArrayExt::<Complex<f64>>::sum_dtype(),
62 CT::Complex => ArrayExt::<Complex<f64>>::sum_dtype(),
63 },
64 NT::Float(ft) => match ft {
65 FT::F32 => ArrayExt::<f32>::sum_dtype(),
66 FT::F64 => ArrayExt::<f64>::sum_dtype(),
67 FT::Float => ArrayExt::<f64>::sum_dtype(),
68 },
69 NT::Int(it) => match it {
70 IT::I8 => ArrayExt::<i16>::sum_dtype(),
71 IT::I16 => ArrayExt::<i16>::sum_dtype(),
72 IT::I32 => ArrayExt::<i32>::sum_dtype(),
73 IT::I64 => ArrayExt::<i64>::sum_dtype(),
74 IT::Int => ArrayExt::<i64>::sum_dtype(),
75 },
76 NT::UInt(ut) => match ut {
77 UT::U8 => ArrayExt::<u8>::sum_dtype(),
78 UT::U16 => ArrayExt::<u16>::sum_dtype(),
79 UT::U32 => ArrayExt::<u32>::sum_dtype(),
80 UT::U64 => ArrayExt::<u64>::sum_dtype(),
81 UT::UInt => ArrayExt::<u64>::sum_dtype(),
82 },
83 NT::Number => ArrayExt::<f64>::sum_dtype(),
84 }
85}
86
/// Call `$func` on the `ArrayExt` held by whichever `Array` variant `$array` is.
///
/// `$func` is expanded separately in every arm, so it may name a generic function that gets
/// instantiated once per element type.
macro_rules! dispatch {
    ($array:expr, $func:expr) => {
        match $array {
            Array::Bool(inner) => $func(inner),
            Array::U8(inner) => $func(inner),
            Array::U16(inner) => $func(inner),
            Array::U32(inner) => $func(inner),
            Array::U64(inner) => $func(inner),
            Array::I16(inner) => $func(inner),
            Array::I32(inner) => $func(inner),
            Array::I64(inner) => $func(inner),
            Array::F32(inner) => $func(inner),
            Array::F64(inner) => $func(inner),
            Array::C32(inner) => $func(inner),
            Array::C64(inner) => $func(inner),
        }
    };
}
105
/// Call `$reducer(inner, $step)` on the `ArrayExt` held by whichever `Array` variant `$array` is.
///
/// Like `dispatch!`, `$reducer` is expanded per arm so it may be a generic function. Only the
/// matching arm runs, so `$step` is evaluated exactly once.
macro_rules! reduce {
    ($array:expr, $reducer:expr, $step:expr) => {
        match $array {
            Array::Bool(inner) => $reducer(inner, $step),
            Array::U8(inner) => $reducer(inner, $step),
            Array::U16(inner) => $reducer(inner, $step),
            Array::U32(inner) => $reducer(inner, $step),
            Array::U64(inner) => $reducer(inner, $step),
            Array::I16(inner) => $reducer(inner, $step),
            Array::I32(inner) => $reducer(inner, $step),
            Array::I64(inner) => $reducer(inner, $step),
            Array::F32(inner) => $reducer(inner, $step),
            Array::F64(inner) => $reducer(inner, $step),
            Array::C32(inner) => $reducer(inner, $step),
            Array::C64(inner) => $reducer(inner, $step),
        }
    };
}
124
/// Generate a public `Array` method named `$fun`.
///
/// The generated method forwards to the `ArrayInstanceTrig` method of the same name on the
/// wrapped `ArrayExt`, then converts the `UnaryOutType` result back into an `Array`.
macro_rules! trig {
    ($fun:ident) => {
        pub fn $fun(&self) -> Array {
            fn apply<T>(this: &ArrayExt<T>) -> Array
            where
                T: af::HasAfEnum + Default,
                Array: From<ArrayExt<T::UnaryOutType>>,
                ArrayExt<T>: ArrayInstanceTrig<T>,
            {
                this.$fun().into()
            }

            dispatch!(self, apply)
        }
    };
}
141
/// An ArrayFire-backed array whose element type is selected at runtime.
///
/// Each variant wraps an `ArrayExt` of one concrete element type, and the variant name names
/// that type. There is no 8-bit signed (`I8`) variant.
#[derive(Clone)]
pub enum Array {
    /// Elements of type `bool`.
    Bool(ArrayExt<bool>),
    /// Elements of type `Complex<f32>`.
    C32(ArrayExt<Complex<f32>>),
    /// Elements of type `Complex<f64>`.
    C64(ArrayExt<Complex<f64>>),
    /// Elements of type `f32`.
    F32(ArrayExt<f32>),
    /// Elements of type `f64`.
    F64(ArrayExt<f64>),
    /// Elements of type `i16`.
    I16(ArrayExt<i16>),
    /// Elements of type `i32`.
    I32(ArrayExt<i32>),
    /// Elements of type `i64`.
    I64(ArrayExt<i64>),
    /// Elements of type `u8`.
    U8(ArrayExt<u8>),
    /// Elements of type `u16`.
    U16(ArrayExt<u16>),
    /// Elements of type `u32`.
    U32(ArrayExt<u32>),
    /// Elements of type `u64`.
    U64(ArrayExt<u64>),
}
158
159impl GetSize for Array {
160 fn get_size(&self) -> usize {
161 self.dtype().size() * self.len()
162 }
163}
164
165impl Array {
166 pub fn type_cast<T: af::HasAfEnum>(&self) -> ArrayExt<T> {
168 dispatch!(self, ArrayExt::type_cast)
169 }
170
171 pub fn concatenate(left: &Array, right: &Array) -> Array {
173 use Array::*;
174 match (left, right) {
175 (Bool(l), Bool(r)) => Bool(ArrayExt::concatenate(l, r)),
176
177 (F32(l), F32(r)) => F32(ArrayExt::concatenate(l, r)),
178 (F64(l), F64(r)) => F64(ArrayExt::concatenate(l, r)),
179
180 (C32(l), C32(r)) => C32(ArrayExt::concatenate(l, r)),
181 (C64(l), C64(r)) => C64(ArrayExt::concatenate(l, r)),
182
183 (I16(l), I16(r)) => I16(ArrayExt::concatenate(l, r)),
184 (I32(l), I32(r)) => I32(ArrayExt::concatenate(l, r)),
185 (I64(l), I64(r)) => I64(ArrayExt::concatenate(l, r)),
186
187 (U8(l), U8(r)) => U8(ArrayExt::concatenate(l, r)),
188 (U16(l), U16(r)) => U16(ArrayExt::concatenate(l, r)),
189 (U32(l), U32(r)) => U32(ArrayExt::concatenate(l, r)),
190 (U64(l), U64(r)) => U64(ArrayExt::concatenate(l, r)),
191
192 (l, r) if l.dtype() > r.dtype() => Array::concatenate(l, &r.cast_into(l.dtype())),
193 (l, r) if l.dtype() < r.dtype() => Array::concatenate(&l.cast_into(r.dtype()), r),
194
195 (l, r) => unreachable!("concatenate {}, {}", l, r),
196 }
197 }
198
199 pub fn constant(value: Number, length: usize) -> Array {
201 use number_general::Complex;
202 use Array::*;
203
204 match value {
205 Number::Bool(b) => {
206 let b: bool = b.into();
207 Bool(ArrayExt::constant(b, length))
208 }
209 Number::Complex(c) => match c {
210 Complex::C32(c) => C32(ArrayExt::constant(c, length)),
211 Complex::C64(c) => C64(ArrayExt::constant(c, length)),
212 },
213 Number::Float(f) => match f {
214 Float::F32(f) => F32(ArrayExt::constant(f, length)),
215 Float::F64(f) => F64(ArrayExt::constant(f, length)),
216 },
217 Number::Int(i) => match i {
218 Int::I16(i) => I16(ArrayExt::constant(i, length)),
219 Int::I32(i) => I32(ArrayExt::constant(i, length)),
220 Int::I64(i) => I64(ArrayExt::constant(i, length)),
221 other => panic!("ArrayFire does not support {}", other),
222 },
223 Number::UInt(u) => match u {
224 UInt::U8(u) => U8(ArrayExt::constant(u, length)),
225 UInt::U16(u) => U16(ArrayExt::constant(u, length)),
226 UInt::U32(u) => U32(ArrayExt::constant(u, length)),
227 UInt::U64(u) => U64(ArrayExt::constant(u, length)),
228 },
229 }
230 }
231
232 pub fn random_normal(dtype: FloatType, length: usize) -> Array {
234 match dtype {
235 FloatType::F32 => Array::F32(ArrayExt::random_normal(length)),
236 _ => Array::F64(ArrayExt::random_normal(length)),
237 }
238 }
239
240 pub fn random_uniform(dtype: FloatType, length: usize) -> Array {
242 match dtype {
243 FloatType::F32 => Array::F32(ArrayExt::random_uniform(length)),
244 _ => Array::F64(ArrayExt::random_uniform(length)),
245 }
246 }
247
248 pub fn dtype(&self) -> NumberType {
250 use number_general::DType;
251 use Array::*;
252
253 match self {
254 Bool(_) => bool::dtype(),
255 C32(_) => Complex::<f32>::dtype(),
256 C64(_) => Complex::<f64>::dtype(),
257 F32(_) => f32::dtype(),
258 F64(_) => f64::dtype(),
259 I16(_) => i16::dtype(),
260 I32(_) => i32::dtype(),
261 I64(_) => i64::dtype(),
262 U8(_) => u8::dtype(),
263 U16(_) => u16::dtype(),
264 U32(_) => u32::dtype(),
265 U64(_) => u64::dtype(),
266 }
267 }
268
269 pub fn cast_into(&self, dtype: NumberType) -> Array {
271 use {ComplexType as CT, FloatType as FT, IntType as IT, NumberType as NT, UIntType as UT};
272
273 match dtype {
274 NT::Bool => Self::Bool(self.type_cast()),
275 NT::Complex(ct) => match ct {
276 CT::C32 => Self::C32(self.type_cast()),
277 CT::C64 => Self::C64(self.type_cast()),
278 CT::Complex => Self::C64(self.type_cast()),
279 },
280 NT::Float(ft) => match ft {
281 FT::F32 => Self::F32(self.type_cast()),
282 FT::F64 => Self::F64(self.type_cast()),
283 FT::Float => Self::F64(self.type_cast()),
284 },
285 NT::Int(it) => match it {
286 IT::I16 => Self::I16(self.type_cast()),
287 IT::I32 => Self::I32(self.type_cast()),
288 IT::I64 => Self::I64(self.type_cast()),
289 IT::Int => Self::I64(self.type_cast()),
290 other => panic!("ArrayFire does not support {}", other),
291 },
292 NT::UInt(ut) => match ut {
293 UT::U8 => Self::U8(self.type_cast()),
294 UT::U16 => Self::U16(self.type_cast()),
295 UT::U32 => Self::U32(self.type_cast()),
296 UT::U64 => Self::U64(self.type_cast()),
297 UT::UInt => Self::U64(self.type_cast()),
298 },
299 NT::Number => self.clone(),
300 }
301 }
302
303 pub fn to_vec(&self) -> Vec<Number> {
305 fn to_vec<T>(this: &ArrayExt<T>) -> Vec<Number>
306 where
307 T: af::HasAfEnum + Clone + Default,
308 Number: From<T>,
309 {
310 this.to_vec().into_iter().map(Number::from).collect()
311 }
312
313 dispatch!(self, to_vec)
314 }
315
316 pub fn abs(&self) -> Array {
318 use Array::*;
319 match self {
320 C32(c) => F32(c.abs()),
321 C64(c) => F64(c.abs()),
322 F32(f) => F32(f.abs()),
323 F64(f) => F64(f.abs()),
324 I16(i) => I16(i.abs()),
325 I32(i) => I32(i.abs()),
326 I64(i) => I64(i.abs()),
327 other => other.clone(),
328 }
329 }
330
331 pub fn all(&self) -> bool {
333 dispatch!(self, ArrayExt::all)
334 }
335
336 pub fn any(&self) -> bool {
338 dispatch!(self, ArrayExt::any)
339 }
340
341 pub fn and(&self, other: &Array) -> Array {
343 let this: ArrayExt<bool> = self.type_cast();
344 let that: ArrayExt<bool> = other.type_cast();
345 Array::Bool(this.and(&that))
346 }
347
348 pub fn and_const(&self, other: Number) -> Array {
350 let this: ArrayExt<bool> = self.type_cast();
351 let that: ArrayExt<bool> = ArrayExt::from(&[other.cast_into()][..]);
352 Array::Bool(this.and(&that))
353 }
354
355 pub fn argmax(&self) -> (usize, Number) {
357 fn imax<T: af::HasAfEnum>(x: &ArrayExt<T>) -> (usize, Number)
358 where
359 ArrayExt<T>: ArrayInstanceIndex,
360 Number: From<<ArrayExt<T> as ArrayInstance>::DType>,
361 {
362 let (i, max) = x.argmax();
363 (i, max.into())
364 }
365
366 dispatch!(self, imax)
367 }
368
369 pub fn eq(&self, other: &Array) -> Array {
371 use Array::*;
372 match (self, other) {
373 (Bool(l), Bool(r)) => Bool(l.eq(r.deref())),
374 (C32(l), C32(r)) => Bool(l.eq(r.deref())),
375 (C64(l), C64(r)) => Bool(l.eq(r.deref())),
376 (F32(l), F32(r)) => Bool(l.eq(r.deref())),
377 (F64(l), F64(r)) => Bool(l.eq(r.deref())),
378 (I16(l), I16(r)) => Bool(l.eq(r.deref())),
379 (I32(l), I32(r)) => Bool(l.eq(r.deref())),
380 (I64(l), I64(r)) => Bool(l.eq(r.deref())),
381 (U8(l), U8(r)) => Bool(l.eq(r.deref())),
382 (U16(l), U16(r)) => Bool(l.eq(r.deref())),
383 (U32(l), U32(r)) => Bool(l.eq(r.deref())),
384 (U64(l), U64(r)) => Bool(l.eq(r.deref())),
385 (l, r) => match (l.dtype(), r.dtype()) {
386 (l_dtype, r_dtype) if l_dtype > r_dtype => l.eq(&r.cast_into(l_dtype)),
387 (l_dtype, r_dtype) if l_dtype < r_dtype => l.cast_into(r_dtype).eq(r),
388 (l, r) => unreachable!("{} equal to {}", l, r),
389 },
390 }
391 }
392
393 pub fn eq_const(&self, other: Number) -> Array {
395 use number_general::Complex;
396 match (self, other) {
397 (Self::Bool(l), Number::Bool(r)) => Self::Bool(l.eq(&bool::from(r))),
398
399 (Self::C32(l), Number::Complex(Complex::C32(r))) => Self::Bool(l.eq(&r)),
400 (Self::C64(l), Number::Complex(Complex::C64(r))) => Self::Bool(l.eq(&r)),
401
402 (Self::F32(l), Number::Float(Float::F32(r))) => Self::Bool(l.eq(&r)),
403 (Self::F64(l), Number::Float(Float::F64(r))) => Self::Bool(l.eq(&r)),
404
405 (Self::I16(l), Number::Int(Int::I16(r))) => Self::Bool(l.eq(&r)),
406 (Self::I32(l), Number::Int(Int::I32(r))) => Self::Bool(l.eq(&r)),
407 (Self::I64(l), Number::Int(Int::I64(r))) => Self::Bool(l.eq(&r)),
408
409 (Self::U8(l), Number::UInt(UInt::U8(r))) => Self::Bool(l.eq(&r)),
410 (Self::U16(l), Number::UInt(UInt::U16(r))) => Self::Bool(l.eq(&r)),
411 (Self::U32(l), Number::UInt(UInt::U32(r))) => Self::Bool(l.eq(&r)),
412 (Self::U64(l), Number::UInt(UInt::U64(r))) => Self::Bool(l.eq(&r)),
413
414 (l, r) => match (l.dtype(), r.class()) {
415 (l_dtype, r_dtype) if l_dtype > r_dtype => l.eq_const(r.into_type(l_dtype)),
416 (l_dtype, r_dtype) if l_dtype < r_dtype => l.cast_into(r_dtype).eq_const(r),
417 (l, r) => unreachable!("{} equal to {}", l, r),
418 },
419 }
420 }
421
422 pub fn exp(&self) -> Array {
424 fn exp<T>(this: &ArrayExt<T>) -> Array
425 where
426 T: af::HasAfEnum + Default,
427 Array: From<ArrayExt<T::UnaryOutType>>,
428 {
429 this.exp().into()
430 }
431
432 dispatch!(self, exp)
433 }
434
435 pub fn gt(&self, other: &Array) -> Array {
437 use Array::*;
438 match (self, other) {
439 (Bool(l), Bool(r)) => Bool(l.gt(r.deref())),
440 (C32(l), C32(r)) => Bool(l.gt(r.deref())),
441 (C64(l), C64(r)) => Bool(l.gt(r.deref())),
442 (F32(l), F32(r)) => Bool(l.gt(r.deref())),
443 (F64(l), F64(r)) => Bool(l.gt(r.deref())),
444 (I16(l), I16(r)) => Bool(l.gt(r.deref())),
445 (I32(l), I32(r)) => Bool(l.gt(r.deref())),
446 (I64(l), I64(r)) => Bool(l.gt(r.deref())),
447 (U8(l), U8(r)) => Bool(l.gt(r.deref())),
448 (U16(l), U16(r)) => Bool(l.gt(r.deref())),
449 (U32(l), U32(r)) => Bool(l.gt(r.deref())),
450 (U64(l), U64(r)) => Bool(l.gt(r.deref())),
451 (l, r) => match (l.dtype(), r.dtype()) {
452 (l_dtype, r_dtype) if l_dtype > r_dtype => l.gt(&r.cast_into(l_dtype)),
453 (l_dtype, r_dtype) if l_dtype < r_dtype => l.cast_into(r_dtype).gt(r),
454 (l, r) => unreachable!("{} greater than {}", l, r),
455 },
456 }
457 }
458
459 pub fn gt_const(&self, other: Number) -> Array {
461 use number_general::Complex;
462 match (self, other) {
463 (Self::Bool(l), Number::Bool(r)) => Self::Bool(l.gt(&bool::from(r))),
464 (Self::C32(l), Number::Complex(Complex::C32(r))) => Self::Bool(l.gt(&r)),
465 (Self::C64(l), Number::Complex(Complex::C64(r))) => Self::Bool(l.gt(&r)),
466 (Self::F32(l), Number::Float(Float::F32(r))) => Self::Bool(l.gt(&r)),
467 (Self::F64(l), Number::Float(Float::F64(r))) => Self::Bool(l.gt(&r)),
468 (Self::I16(l), Number::Int(Int::I16(r))) => Self::Bool(l.gt(&r)),
469 (Self::I32(l), Number::Int(Int::I32(r))) => Self::Bool(l.gt(&r)),
470 (Self::I64(l), Number::Int(Int::I64(r))) => Self::Bool(l.gt(&r)),
471 (Self::U8(l), Number::UInt(UInt::U8(r))) => Self::Bool(l.gt(&r)),
472 (Self::U16(l), Number::UInt(UInt::U16(r))) => Self::Bool(l.gt(&r)),
473 (Self::U32(l), Number::UInt(UInt::U32(r))) => Self::Bool(l.gt(&r)),
474 (Self::U64(l), Number::UInt(UInt::U64(r))) => Self::Bool(l.gt(&r)),
475 (l, r) => match (l.dtype(), r.class()) {
476 (l_dtype, r_dtype) if l_dtype > r_dtype => l.gt_const(r.into_type(l_dtype)),
477 (l_dtype, r_dtype) if l_dtype < r_dtype => l.cast_into(r_dtype).gt_const(r),
478 (l, r) => unreachable!("{} greater than {}", l, r),
479 },
480 }
481 }
482
483 pub fn gte(&self, other: &Array) -> Array {
485 use Array::*;
486 match (self, other) {
487 (Bool(l), Bool(r)) => Bool(l.gte(r.deref())),
488 (C32(l), C32(r)) => Bool(l.gte(r.deref())),
489 (C64(l), C64(r)) => Bool(l.gte(r.deref())),
490 (F32(l), F32(r)) => Bool(l.gte(r.deref())),
491 (F64(l), F64(r)) => Bool(l.gte(r.deref())),
492 (I16(l), I16(r)) => Bool(l.gte(r.deref())),
493 (I32(l), I32(r)) => Bool(l.gte(r.deref())),
494 (I64(l), I64(r)) => Bool(l.gte(r.deref())),
495 (U8(l), U8(r)) => Bool(l.gte(r.deref())),
496 (U16(l), U16(r)) => Bool(l.gte(r.deref())),
497 (U32(l), U32(r)) => Bool(l.gte(r.deref())),
498 (U64(l), U64(r)) => Bool(l.gte(r.deref())),
499 (l, r) => match (l.dtype(), r.dtype()) {
500 (l_dtype, r_dtype) if l_dtype > r_dtype => l.gte(&r.cast_into(l_dtype)),
501 (l_dtype, r_dtype) if l_dtype < r_dtype => l.cast_into(r_dtype).gte(r),
502 (l, r) => unreachable!("{} greater than or equal to {}", l, r),
503 },
504 }
505 }
506
507 pub fn gte_const(&self, other: Number) -> Array {
509 use number_general::Complex;
510 match (self, other) {
511 (Self::Bool(l), Number::Bool(r)) => Self::Bool(l.gte(&bool::from(r))),
512
513 (Self::C32(l), Number::Complex(Complex::C32(r))) => Self::Bool(l.gte(&r)),
514 (Self::C64(l), Number::Complex(Complex::C64(r))) => Self::Bool(l.gte(&r)),
515
516 (Self::F32(l), Number::Float(Float::F32(r))) => Self::Bool(l.gte(&r)),
517 (Self::F64(l), Number::Float(Float::F64(r))) => Self::Bool(l.gte(&r)),
518
519 (Self::I16(l), Number::Int(Int::I16(r))) => Self::Bool(l.gte(&r)),
520 (Self::I32(l), Number::Int(Int::I32(r))) => Self::Bool(l.gte(&r)),
521 (Self::I64(l), Number::Int(Int::I64(r))) => Self::Bool(l.gte(&r)),
522
523 (Self::U8(l), Number::UInt(UInt::U8(r))) => Self::Bool(l.gte(&r)),
524 (Self::U16(l), Number::UInt(UInt::U16(r))) => Self::Bool(l.gte(&r)),
525 (Self::U32(l), Number::UInt(UInt::U32(r))) => Self::Bool(l.gte(&r)),
526 (Self::U64(l), Number::UInt(UInt::U64(r))) => Self::Bool(l.gte(&r)),
527
528 (l, r) => match (l.dtype(), r.class()) {
529 (l_dtype, r_dtype) if l_dtype > r_dtype => l.gte_const(r.into_type(l_dtype)),
530 (l_dtype, r_dtype) if l_dtype < r_dtype => l.cast_into(r_dtype).gte_const(r),
531 (l, r) => unreachable!("{} greater than or equal to {}", l, r),
532 },
533 }
534 }
535
536 pub fn is_infinite(&self) -> Array {
538 fn is_infinite<T>(this: &ArrayExt<T>) -> Array
539 where
540 T: af::HasAfEnum + Default,
541 ArrayExt<T>: ArrayInstanceUnreal,
542 {
543 this.is_infinite().into()
544 }
545
546 dispatch!(self, is_infinite)
547 }
548
549 pub fn is_nan(&self) -> Array {
551 fn is_nan<T>(this: &ArrayExt<T>) -> Array
552 where
553 T: af::HasAfEnum + Default,
554 ArrayExt<T>: ArrayInstanceUnreal,
555 {
556 this.is_nan().into()
557 }
558
559 dispatch!(self, is_nan)
560 }
561
562 pub fn ln(&self) -> Array {
564 fn ln<T>(this: &ArrayExt<T>) -> Array
565 where
566 T: af::HasAfEnum + Default,
567 Array: From<ArrayExt<T::UnaryOutType>>,
568 {
569 this.ln().into()
570 }
571
572 dispatch!(self, ln)
573 }
574
575 pub fn log(&self, base: &Array) -> Array {
577 use Array::*;
578 match (self, base) {
579 (Bool(l), Bool(r)) => l.log(r).into(),
580 (C32(l), C32(r)) => l.log(r).into(),
581 (C64(l), C64(r)) => l.log(r).into(),
582 (F32(l), F32(r)) => l.log(r).into(),
583 (F64(l), F64(r)) => l.log(r).into(),
584 (I16(l), I16(r)) => l.log(r).into(),
585 (I32(l), I32(r)) => l.log(r).into(),
586 (I64(l), I64(r)) => l.log(r).into(),
587 (U8(l), U8(r)) => l.log(r).into(),
588 (U16(l), U16(r)) => l.log(r).into(),
589 (U32(l), U32(r)) => l.log(r).into(),
590 (U64(l), U64(r)) => l.log(r).into(),
591 (l, r) => match (l.dtype(), r.dtype()) {
592 (l_dtype, r_dtype) if l_dtype > r_dtype => l.log(&r.cast_into(l_dtype)),
593 (l_dtype, r_dtype) if l_dtype < r_dtype => l.cast_into(r_dtype).log(&r),
594 (l, r) => unreachable!("{} log {}", l, r),
595 },
596 }
597 }
598
599 pub fn log_const(&self, base: Number) -> Array {
601 (&self.ln()) / base.ln()
602 }
603
604 pub fn lt(&self, other: &Array) -> Array {
606 use Array::*;
607 match (self, other) {
608 (Bool(l), Bool(r)) => Bool(l.lt(r.deref())),
609 (C32(l), C32(r)) => Bool(l.lt(r.deref())),
610 (C64(l), C64(r)) => Bool(l.lt(r.deref())),
611 (F32(l), F32(r)) => Bool(l.lt(r.deref())),
612 (F64(l), F64(r)) => Bool(l.lt(r.deref())),
613 (I16(l), I16(r)) => Bool(l.lt(r.deref())),
614 (I32(l), I32(r)) => Bool(l.lt(r.deref())),
615 (I64(l), I64(r)) => Bool(l.lt(r.deref())),
616 (U8(l), U8(r)) => Bool(l.lt(r.deref())),
617 (U16(l), U16(r)) => Bool(l.lt(r.deref())),
618 (U32(l), U32(r)) => Bool(l.lt(r.deref())),
619 (U64(l), U64(r)) => Bool(l.lt(r.deref())),
620 (l, r) => match (l.dtype(), r.dtype()) {
621 (l_dtype, r_dtype) if l_dtype > r_dtype => l.lt(&r.cast_into(l_dtype)),
622 (l_dtype, r_dtype) if l_dtype < r_dtype => l.cast_into(r_dtype).lt(r),
623 (l, r) => unreachable!("{} less than {}", l, r),
624 },
625 }
626 }
627
628 pub fn lt_const(&self, other: Number) -> Array {
630 use number_general::Complex;
631 match (self, other) {
632 (Self::Bool(l), Number::Bool(r)) => Self::Bool(l.lt(&bool::from(r))),
633
634 (Self::C32(l), Number::Complex(Complex::C32(r))) => Self::Bool(l.lt(&r)),
635 (Self::C64(l), Number::Complex(Complex::C64(r))) => Self::Bool(l.lt(&r)),
636
637 (Self::F32(l), Number::Float(Float::F32(r))) => Self::Bool(l.lt(&r)),
638 (Self::F64(l), Number::Float(Float::F64(r))) => Self::Bool(l.lt(&r)),
639
640 (Self::I16(l), Number::Int(Int::I16(r))) => Self::Bool(l.lt(&r)),
641 (Self::I32(l), Number::Int(Int::I32(r))) => Self::Bool(l.lt(&r)),
642 (Self::I64(l), Number::Int(Int::I64(r))) => Self::Bool(l.lt(&r)),
643
644 (Self::U8(l), Number::UInt(UInt::U8(r))) => Self::Bool(l.lt(&r)),
645 (Self::U16(l), Number::UInt(UInt::U16(r))) => Self::Bool(l.lt(&r)),
646 (Self::U32(l), Number::UInt(UInt::U32(r))) => Self::Bool(l.lt(&r)),
647 (Self::U64(l), Number::UInt(UInt::U64(r))) => Self::Bool(l.lt(&r)),
648
649 (l, r) => match (l.dtype(), r.class()) {
650 (l_dtype, r_dtype) if l_dtype > r_dtype => l.lt_const(r.into_type(l_dtype)),
651 (l_dtype, r_dtype) if l_dtype < r_dtype => l.cast_into(r_dtype).lt_const(r),
652 (l, r) => unreachable!("{} less than {}", l, r),
653 },
654 }
655 }
656
657 pub fn lte(&self, other: &Array) -> Array {
659 use Array::*;
660 match (self, other) {
661 (Bool(l), Bool(r)) => Bool(l.lte(r.deref())),
662 (C32(l), C32(r)) => Bool(l.lte(r.deref())),
663 (C64(l), C64(r)) => Bool(l.lte(r.deref())),
664 (F32(l), F32(r)) => Bool(l.lte(r.deref())),
665 (F64(l), F64(r)) => Bool(l.lte(r.deref())),
666 (I16(l), I16(r)) => Bool(l.lte(r.deref())),
667 (I32(l), I32(r)) => Bool(l.lte(r.deref())),
668 (I64(l), I64(r)) => Bool(l.lte(r.deref())),
669 (U8(l), U8(r)) => Bool(l.lte(r.deref())),
670 (U16(l), U16(r)) => Bool(l.lte(r.deref())),
671 (U32(l), U32(r)) => Bool(l.lte(r.deref())),
672 (U64(l), U64(r)) => Bool(l.lte(r.deref())),
673 (l, r) => match (l.dtype(), r.dtype()) {
674 (l_dtype, r_dtype) if l_dtype > r_dtype => l.lte(&r.cast_into(l_dtype)),
675 (l_dtype, r_dtype) if l_dtype < r_dtype => l.cast_into(r_dtype).lte(r),
676 (l, r) => unreachable!("{} less than or equal to {}", l, r),
677 },
678 }
679 }
680
681 pub fn lte_const(&self, other: Number) -> Array {
683 use number_general::Complex;
684 match (self, other) {
685 (Self::Bool(l), Number::Bool(r)) => Self::Bool(l.lte(&bool::from(r))),
686
687 (Self::C32(l), Number::Complex(Complex::C32(r))) => Self::Bool(l.lte(&r)),
688 (Self::C64(l), Number::Complex(Complex::C64(r))) => Self::Bool(l.lte(&r)),
689
690 (Self::F32(l), Number::Float(Float::F32(r))) => Self::Bool(l.lte(&r)),
691 (Self::F64(l), Number::Float(Float::F64(r))) => Self::Bool(l.lte(&r)),
692
693 (Self::I16(l), Number::Int(Int::I16(r))) => Self::Bool(l.lte(&r)),
694 (Self::I32(l), Number::Int(Int::I32(r))) => Self::Bool(l.lte(&r)),
695 (Self::I64(l), Number::Int(Int::I64(r))) => Self::Bool(l.lte(&r)),
696
697 (Self::U8(l), Number::UInt(UInt::U8(r))) => Self::Bool(l.lte(&r)),
698 (Self::U16(l), Number::UInt(UInt::U16(r))) => Self::Bool(l.lte(&r)),
699 (Self::U32(l), Number::UInt(UInt::U32(r))) => Self::Bool(l.lte(&r)),
700 (Self::U64(l), Number::UInt(UInt::U64(r))) => Self::Bool(l.lte(&r)),
701
702 (l, r) => match (l.dtype(), r.class()) {
703 (l_dtype, r_dtype) if l_dtype > r_dtype => l.lte_const(r.into_type(l_dtype)),
704 (l_dtype, r_dtype) if l_dtype < r_dtype => l.cast_into(r_dtype).lte_const(r),
705 (l, r) => unreachable!("{} less than or equal to {}", l, r),
706 },
707 }
708 }
709
710 pub fn ne(&self, other: &Array) -> Array {
712 use Array::*;
713 match (self, other) {
714 (Bool(l), Bool(r)) => Bool(l.ne(r.deref())),
715 (C32(l), C32(r)) => Bool(l.ne(r.deref())),
716 (C64(l), C64(r)) => Bool(l.ne(r.deref())),
717 (F32(l), F32(r)) => Bool(l.ne(r.deref())),
718 (F64(l), F64(r)) => Bool(l.ne(r.deref())),
719 (I16(l), I16(r)) => Bool(l.ne(r.deref())),
720 (I32(l), I32(r)) => Bool(l.ne(r.deref())),
721 (I64(l), I64(r)) => Bool(l.ne(r.deref())),
722 (U8(l), U8(r)) => Bool(l.ne(r.deref())),
723 (U16(l), U16(r)) => Bool(l.ne(r.deref())),
724 (U32(l), U32(r)) => Bool(l.ne(r.deref())),
725 (U64(l), U64(r)) => Bool(l.ne(r.deref())),
726 (l, r) => match (l.dtype(), r.dtype()) {
727 (l_dtype, r_dtype) if l_dtype > r_dtype => l.ne(&r.cast_into(l_dtype)),
728 (l_dtype, r_dtype) if l_dtype < r_dtype => l.cast_into(r_dtype).ne(r),
729 (l, r) => unreachable!("{} not equal to {}", l, r),
730 },
731 }
732 }
733
734 pub fn ne_const(&self, other: Number) -> Array {
736 use number_general::Complex;
737 match (self, other) {
738 (Self::Bool(l), Number::Bool(r)) => Self::Bool(l.ne(&bool::from(r))),
739
740 (Self::C32(l), Number::Complex(Complex::C32(r))) => Self::Bool(l.ne(&r)),
741 (Self::C64(l), Number::Complex(Complex::C64(r))) => Self::Bool(l.ne(&r)),
742
743 (Self::F32(l), Number::Float(Float::F32(r))) => Self::Bool(l.ne(&r)),
744 (Self::F64(l), Number::Float(Float::F64(r))) => Self::Bool(l.ne(&r)),
745
746 (Self::I16(l), Number::Int(Int::I16(r))) => Self::Bool(l.ne(&r)),
747 (Self::I32(l), Number::Int(Int::I32(r))) => Self::Bool(l.ne(&r)),
748 (Self::I64(l), Number::Int(Int::I64(r))) => Self::Bool(l.ne(&r)),
749
750 (Self::U8(l), Number::UInt(UInt::U8(r))) => Self::Bool(l.ne(&r)),
751 (Self::U16(l), Number::UInt(UInt::U16(r))) => Self::Bool(l.ne(&r)),
752 (Self::U32(l), Number::UInt(UInt::U32(r))) => Self::Bool(l.ne(&r)),
753 (Self::U64(l), Number::UInt(UInt::U64(r))) => Self::Bool(l.ne(&r)),
754
755 (l, r) => match (l.dtype(), r.class()) {
756 (l_dtype, r_dtype) if l_dtype > r_dtype => l.ne_const(r.into_type(l_dtype)),
757 (l_dtype, r_dtype) if l_dtype < r_dtype => l.cast_into(r_dtype).ne_const(r),
758 (l, r) => unreachable!("{} not equal to {}", l, r),
759 },
760 }
761 }
762
763 pub fn not(&self) -> Array {
765 let this: ArrayExt<bool> = self.type_cast();
766 Array::Bool(this.not())
767 }
768
769 pub fn or(&self, other: &Array) -> Array {
771 let this: ArrayExt<bool> = self.type_cast();
772 let that: ArrayExt<bool> = other.type_cast();
773 Array::Bool(this.or(&that))
774 }
775
776 pub fn or_const(&self, other: Number) -> Array {
778 let this: ArrayExt<bool> = self.type_cast();
779 let that: ArrayExt<bool> = ArrayExt::from(&[other.cast_into()][..]);
780 Array::Bool(this.or(&that))
781 }
782
783 pub fn reduce_max(&self, stride: u64) -> Result<Array> {
785 if self.len() as u64 % stride != 0 {
786 return Err(error(format!(
787 "cannot reduce an Array of length {} with stride {}",
788 self.len(),
789 stride
790 )));
791 }
792
793 fn reduce_block_dispatch<T: af::HasAfEnum>(block: &ArrayExt<T>, stride: u64) -> Array
794 where
795 Array: From<ArrayExt<T::InType>>,
796 {
797 reduce_block(block, stride, &mut |block| af::max(&block, 0).into()).into()
798 }
799
800 Ok(reduce!(self, reduce_block_dispatch, stride))
801 }
802
803 pub fn reduce_min(&self, stride: u64) -> Result<Array> {
805 if self.len() as u64 % stride != 0 {
806 return Err(error(format!(
807 "cannot reduce an Array of length {} with stride {}",
808 self.len(),
809 stride
810 )));
811 }
812
813 fn reduce_block_dispatch<T: af::HasAfEnum>(block: &ArrayExt<T>, stride: u64) -> Array
814 where
815 Array: From<ArrayExt<T::InType>>,
816 {
817 reduce_block(block, stride, &mut |block| af::min(&block, 0).into()).into()
818 }
819
820 Ok(reduce!(self, reduce_block_dispatch, stride))
821 }
822
823 pub fn reduce_product(&self, stride: u64) -> Result<Array> {
825 if self.len() as u64 % stride != 0 {
826 return Err(error(format!(
827 "cannot reduce an Array of length {} with stride {}",
828 self.len(),
829 stride
830 )));
831 }
832
833 fn reduce_block_dispatch<T: af::HasAfEnum>(block: &ArrayExt<T>, stride: u64) -> Array
834 where
835 Array: From<ArrayExt<T::ProductOutType>>,
836 {
837 reduce_block(block, stride, &mut |block| af::product(&block, 0).into()).into()
838 }
839
840 Ok(reduce!(self, reduce_block_dispatch, stride))
841 }
842
843 pub fn reduce_sum(&self, stride: u64) -> Result<Array> {
845 if self.len() as u64 % stride != 0 {
846 return Err(error(format!(
847 "cannot reduce an Array of length {} with stride {}",
848 self.len(),
849 stride
850 )));
851 }
852
853 fn reduce_block_dispatch<T: af::HasAfEnum>(block: &ArrayExt<T>, stride: u64) -> Array
854 where
855 Array: From<ArrayExt<T::AggregateOutType>>,
856 {
857 reduce_block(block, stride, &mut |block| af::sum(&block, 0).into()).into()
858 }
859
860 Ok(reduce!(self, reduce_block_dispatch, stride))
861 }
862
863 pub fn max(&self) -> Number {
865 fn max<T>(this: &ArrayExt<T>) -> Number
866 where
867 T: af::HasAfEnum + Default,
868 T::AggregateOutType: number_general::DType,
869 T::ProductOutType: number_general::DType,
870 ArrayExt<T>: ArrayInstanceMinMax<T>,
871 Number: From<T>,
872 {
873 this.max().into()
874 }
875
876 dispatch!(self, max)
877 }
878
879 pub fn min(&self) -> Number {
881 fn min<T>(this: &ArrayExt<T>) -> Number
882 where
883 T: af::HasAfEnum + Default,
884 T::AggregateOutType: number_general::DType,
885 T::ProductOutType: number_general::DType,
886 ArrayExt<T>: ArrayInstanceMinMax<T>,
887 Number: From<T>,
888 {
889 this.min().into()
890 }
891
892 dispatch!(self, min)
893 }
894
895 pub fn product(&self) -> Number {
897 fn product<T>(this: &ArrayExt<T>) -> Number
898 where
899 T: af::HasAfEnum + Default,
900 T::AggregateOutType: number_general::DType,
901 T::ProductOutType: number_general::DType,
902 ArrayExt<T>: ArrayInstanceProduct<T>,
903 Number: From<T::ProductOutType>,
904 {
905 this.product().into()
906 }
907
908 dispatch!(self, product)
909 }
910
911 pub fn sum(&self) -> Number {
913 fn sum<T>(this: &ArrayExt<T>) -> Number
914 where
915 T: af::HasAfEnum + Default,
916 T::AggregateOutType: number_general::DType,
917 T::ProductOutType: number_general::DType,
918 ArrayExt<T>: ArrayInstanceSum<T>,
919 Number: From<T::AggregateOutType>,
920 {
921 this.sum().into()
922 }
923
924 dispatch!(self, sum)
925 }
926
927 pub fn len(&self) -> usize {
929 dispatch!(self, ArrayExt::len)
930 }
931
932 pub fn get_value(&self, index: usize) -> Number {
934 debug_assert!(index < self.len());
935
936 use number_general::Complex;
937 use Array::*;
938 match self {
939 Bool(b) => b.get_value(index).into(),
940 C32(c) => Complex::from(c.get_value(index)).into(),
941 C64(c) => Complex::from(c.get_value(index)).into(),
942 F32(f) => Float::from(f.get_value(index)).into(),
943 F64(f) => Float::from(f.get_value(index)).into(),
944 I16(i) => Int::from(i.get_value(index)).into(),
945 I32(i) => Int::from(i.get_value(index)).into(),
946 I64(i) => Int::from(i.get_value(index)).into(),
947 U8(u) => UInt::from(u.get_value(index)).into(),
948 U16(u) => UInt::from(u.get_value(index)).into(),
949 U32(u) => UInt::from(u.get_value(index)).into(),
950 U64(u) => UInt::from(u.get_value(index)).into(),
951 }
952 }
953
954 pub fn get(&self, index: &ArrayExt<u64>) -> Self {
956 let mut indexer = af::Indexer::default();
957 indexer.set_index(index.deref(), 0, None);
958 self.get_at(indexer)
959 }
960
961 fn get_at(&self, index: af::Indexer) -> Self {
962 use Array::*;
963 match self {
964 Bool(b) => Bool(b.get(index)),
965 C32(c) => C32(c.get(index)),
966 C64(c) => C64(c.get(index)),
967 F32(f) => F32(f.get(index)),
968 F64(f) => F64(f.get(index)),
969 I16(i) => I16(i.get(index)),
970 I32(i) => I32(i.get(index)),
971 I64(i) => I64(i.get(index)),
972 U8(i) => U8(i.get(index)),
973 U16(i) => U16(i.get(index)),
974 U32(i) => U32(i.get(index)),
975 U64(i) => U64(i.get(index)),
976 }
977 }
978
979 pub fn pow(&self, other: &Self) -> Self {
981 use Array::*;
983 match (self, other) {
984 (C32(l), C32(r)) => C32(l.pow(r.deref())),
985 (C64(l), C64(r)) => C64(l.pow(r.deref())),
986 (F32(l), F32(r)) => F32(l.pow(r.deref())),
987 (F64(l), F64(r)) => F64(l.pow(r.deref())),
988 (l, r) => match (l.dtype(), r.dtype()) {
989 (l_dtype, r_dtype) if l_dtype > r_dtype => l.pow(&r.cast_into(l_dtype)),
990 (l_dtype, r_dtype) if l_dtype < r_dtype => l.cast_into(r_dtype).pow(r),
991 _ => Self::F64(l.type_cast()).pow(r),
992 },
993 }
994 }
995
996 pub fn pow_const(&self, other: Number) -> Self {
998 use number_general::Complex;
1000 match (self, other) {
1001 (Self::C32(l), Number::Complex(Complex::C32(r))) => Self::C32(l.pow(&r)),
1002 (Self::C64(l), Number::Complex(Complex::C64(r))) => Self::C64(l.pow(&r)),
1003 (Self::F32(l), Number::Float(Float::F32(r))) => Self::F32(l.pow(&r)),
1004 (Self::F64(l), Number::Float(Float::F64(r))) => Self::F64(l.pow(&r)),
1005 (l, r) => match (l.dtype(), r.class()) {
1006 (l_dtype, r_dtype) if l_dtype > r_dtype => l.pow_const(r.into_type(l_dtype)),
1007 (l_dtype, r_dtype) if l_dtype < r_dtype => l.cast_into(r_dtype).pow_const(r),
1008 _ => Self::F64(l.type_cast()).pow_const(r),
1009 },
1010 }
1011 }
1012
    /// Rounds each element of this `Array` using `ArrayInstanceRound::round`.
    ///
    /// The element type of the result is the implementation's associated `Round` type
    /// for each input element type.
    pub fn round(&self) -> Self {
        // Per-variant helper; `dispatch!` below refers to it by the name `round`,
        // so it must not be renamed without updating that call.
        fn round<T: af::HasAfEnum>(x: &ArrayExt<T>) -> Array
        where
            Array: From<ArrayExt<<ArrayExt<T> as ArrayInstanceRound>::Round>>,
        {
            x.round().into()
        }

        dispatch!(self, round)
    }
1024
1025 pub fn set(&mut self, index: &ArrayExt<u64>, other: &Array) -> Result<()> {
1027 let mut indexer = af::Indexer::default();
1028 indexer.set_index(index.deref(), 0, None);
1029 self.set_at(indexer, other)
1030 }
1031
1032 pub fn set_value(&mut self, offset: usize, value: Number) -> Result<()> {
1034 use Array::*;
1035 match self {
1036 Bool(b) => {
1037 let value: Boolean = value.cast_into();
1038 b.set_at(offset, value.cast_into());
1039 }
1040 C32(c) => {
1041 let value: Complex<f32> = value.cast_into();
1042 c.set_at(offset, value.cast_into())
1043 }
1044 C64(c) => {
1045 let value: Complex<f64> = value.cast_into();
1046 c.set_at(offset, value.cast_into())
1047 }
1048 F32(f) => {
1049 let value: Float = value.cast_into();
1050 f.set_at(offset, value.cast_into())
1051 }
1052 F64(f) => {
1053 let value: Float = value.cast_into();
1054 f.set_at(offset, value.cast_into())
1055 }
1056 I16(i) => {
1057 let value: Int = value.cast_into();
1058 i.set_at(offset, value.cast_into())
1059 }
1060 I32(i) => {
1061 let value: Int = value.cast_into();
1062 i.set_at(offset, value.cast_into())
1063 }
1064 I64(i) => {
1065 let value: Int = value.cast_into();
1066 i.set_at(offset, value.cast_into())
1067 }
1068 U8(u) => {
1069 let value: UInt = value.cast_into();
1070 u.set_at(offset, value.cast_into())
1071 }
1072 U16(u) => {
1073 let value: UInt = value.cast_into();
1074 u.set_at(offset, value.cast_into())
1075 }
1076 U32(u) => {
1077 let value: UInt = value.cast_into();
1078 u.set_at(offset, value.cast_into())
1079 }
1080 U64(u) => {
1081 let value: UInt = value.cast_into();
1082 u.set_at(offset, value.cast_into())
1083 }
1084 }
1085
1086 Ok(())
1087 }
1088
1089 fn set_at(&mut self, index: af::Indexer, value: &Array) -> Result<()> {
1090 use Array::*;
1091 match self {
1092 Bool(l) => l.set(&index, &value.type_cast()),
1093 C32(l) => l.set(&index, &value.type_cast()),
1094 C64(l) => l.set(&index, &value.type_cast()),
1095 F32(l) => l.set(&index, &value.type_cast()),
1096 F64(l) => l.set(&index, &value.type_cast()),
1097 I16(l) => l.set(&index, &value.type_cast()),
1098 I32(l) => l.set(&index, &value.type_cast()),
1099 I64(l) => l.set(&index, &value.type_cast()),
1100 U8(l) => l.set(&index, &value.type_cast()),
1101 U16(l) => l.set(&index, &value.type_cast()),
1102 U32(l) => l.set(&index, &value.type_cast()),
1103 U64(l) => l.set(&index, &value.type_cast()),
1104 }
1105
1106 Ok(())
1107 }
1108
1109 pub fn slice(&self, start: usize, end: usize) -> Result<Self> {
1111 if start > self.len() {
1112 return Err(error(format!(
1113 "invalid start index for array slice: {}",
1114 start
1115 )));
1116 }
1117
1118 if end > self.len() {
1119 return Err(error(format!(
1120 "invalid start index for array slice: {}",
1121 end
1122 )));
1123 }
1124
1125 use Array::*;
1126 let slice = match self {
1127 Bool(b) => b.slice(start, end).into(),
1128 C32(c) => c.slice(start, end).into(),
1129 C64(c) => c.slice(start, end).into(),
1130 F32(f) => f.slice(start, end).into(),
1131 F64(f) => f.slice(start, end).into(),
1132 I16(i) => i.slice(start, end).into(),
1133 I32(i) => i.slice(start, end).into(),
1134 I64(i) => i.slice(start, end).into(),
1135 U8(u) => u.slice(start, end).into(),
1136 U16(u) => u.slice(start, end).into(),
1137 U32(u) => u.slice(start, end).into(),
1138 U64(u) => u.slice(start, end).into(),
1139 };
1140
1141 Ok(slice)
1142 }
1143
1144 pub fn argsort(&self, ascending: bool) -> Result<(Self, ArrayExt<u64>)> {
1146 macro_rules! argsort {
1147 ($arr:expr) => {{
1148 let (sorted, indices) = $arr.sort_index(ascending);
1149 (sorted.into(), indices.type_cast())
1150 }};
1151 }
1152
1153 use Array::*;
1154 let (sorted, indices) = match self {
1155 Bool(b) => argsort!(b),
1156 F32(f) => argsort!(f),
1157 F64(f) => argsort!(f),
1158 I16(i) => argsort!(i),
1159 I32(i) => argsort!(i),
1160 I64(i) => argsort!(i),
1161 U8(u) => argsort!(u),
1162 U16(u) => argsort!(u),
1163 U32(u) => argsort!(u),
1164 U64(u) => argsort!(u),
1165 other => {
1166 return Err(error(format!(
1167 "{} does not support ordering",
1168 other.dtype()
1169 )))
1170 }
1171 };
1172
1173 Ok((sorted, indices))
1174 }
1175
1176 pub fn sort(&mut self, ascending: bool) -> Result<()> {
1178 use Array::*;
1179 match self {
1180 Bool(b) => b.sort(ascending),
1181 F32(f) => f.sort(ascending),
1182 F64(f) => f.sort(ascending),
1183 I16(i) => i.sort(ascending),
1184 I32(i) => i.sort(ascending),
1185 I64(i) => i.sort(ascending),
1186 U8(u) => u.sort(ascending),
1187 U16(u) => u.sort(ascending),
1188 U32(u) => u.sort(ascending),
1189 U64(u) => u.sort(ascending),
1190 other => {
1191 return Err(error(format!(
1192 "{} does not support ordering",
1193 other.dtype()
1194 )))
1195 }
1196 }
1197
1198 Ok(())
1199 }
1200
1201 pub fn split(&self, at: usize) -> Result<(Array, Array)> {
1203 if at > self.len() {
1204 return Err(error(format!(
1205 "Invalid pivot for Array of length {}",
1206 self.len()
1207 )));
1208 }
1209
1210 use Array::*;
1211 match self {
1212 Bool(u) => {
1213 let (l, r) = u.split(at);
1214 Ok((Bool(l), Bool(r)))
1215 }
1216 C32(u) => {
1217 let (l, r) = u.split(at);
1218 Ok((C32(l), C32(r)))
1219 }
1220 C64(u) => {
1221 let (l, r) = u.split(at);
1222 Ok((C64(l), C64(r)))
1223 }
1224 F32(u) => {
1225 let (l, r) = u.split(at);
1226 Ok((F32(l), F32(r)))
1227 }
1228 F64(u) => {
1229 let (l, r) = u.split(at);
1230 Ok((F64(l), F64(r)))
1231 }
1232 I16(u) => {
1233 let (l, r) = u.split(at);
1234 Ok((I16(l), I16(r)))
1235 }
1236 I32(u) => {
1237 let (l, r) = u.split(at);
1238 Ok((I32(l), I32(r)))
1239 }
1240 I64(u) => {
1241 let (l, r) = u.split(at);
1242 Ok((I64(l), I64(r)))
1243 }
1244 U8(u) => {
1245 let (l, r) = u.split(at);
1246 Ok((U8(l), U8(r)))
1247 }
1248 U16(u) => {
1249 let (l, r) = u.split(at);
1250 Ok((U16(l), U16(r)))
1251 }
1252 U32(u) => {
1253 let (l, r) = u.split(at);
1254 Ok((U32(l), U32(r)))
1255 }
1256 U64(u) => {
1257 let (l, r) = u.split(at);
1258 Ok((U64(l), U64(r)))
1259 }
1260 }
1261 }
1262
    // Element-wise trigonometric and hyperbolic functions.
    // NOTE(review): each `trig!` invocation presumably expands to a method on `Array`
    // named after its argument — confirm against the macro definition.
    trig! {sin}
    trig! {asin}
    trig! {sinh}
    trig! {asinh}
    trig! {cos}
    trig! {acos}
    trig! {cosh}
    trig! {acosh}
    trig! {tan}
    trig! {atan}
    trig! {tanh}
    trig! {atanh}
1277
1278 pub fn xor(&self, other: &Array) -> Array {
1280 let this: ArrayExt<bool> = self.type_cast();
1281 let that: ArrayExt<bool> = other.type_cast();
1282 Array::Bool(this.xor(&that))
1283 }
1284
1285 pub fn xor_const(&self, other: Number) -> Array {
1287 let this: ArrayExt<bool> = self.type_cast();
1288 let that: ArrayExt<bool> = ArrayExt::from(&[other.cast_into()][..]);
1289 Array::Bool(this.xor(&that))
1290 }
1291}
1292
1293impl PartialEq for Array {
1294 fn eq(&self, other: &Array) -> bool {
1295 if self.len() != other.len() {
1296 return false;
1297 } else {
1298 Array::eq(self, other).all()
1299 }
1300 }
1301}
1302
1303impl Add for &Array {
1304 type Output = Array;
1305
1306 fn add(self, other: &Array) -> Self::Output {
1307 use Array::*;
1308 match (self, other) {
1309 (Bool(l), Bool(r)) => Bool(l + r),
1310 (C32(l), C32(r)) => C32(l + r),
1311 (C64(l), C64(r)) => C64(l + r),
1312 (F32(l), F32(r)) => F32(l + r),
1313 (F64(l), F64(r)) => F64(l + r),
1314 (I16(l), I16(r)) => I16(l + r),
1315 (I32(l), I32(r)) => I32(l + r),
1316 (I64(l), I64(r)) => I64(l + r),
1317 (U8(l), U8(r)) => U8(l + r),
1318 (U16(l), U16(r)) => U16(l + r),
1319 (U32(l), U32(r)) => U32(l + r),
1320 (U64(l), U64(r)) => U64(l + r),
1321 (l, r) => match (l.dtype(), r.dtype()) {
1322 (l_dtype, r_dtype) if l_dtype > r_dtype => l + &r.cast_into(l_dtype),
1323 (l_dtype, r_dtype) if l_dtype < r_dtype => &l.cast_into(r_dtype) + r,
1324 (l, r) => unreachable!("add {}, {}", l, r),
1325 },
1326 }
1327 }
1328}
1329
1330impl Add<Number> for &Array {
1331 type Output = Array;
1332
1333 fn add(self, rhs: Number) -> Self::Output {
1334 use number_general::Complex;
1335 match (self, rhs) {
1336 (Array::Bool(l), Number::Bool(r)) => Array::Bool((l.deref() + bool::from(r)).into()),
1337
1338 (Array::F32(l), Number::Float(Float::F32(r))) => Array::F32((l.deref() + r).into()),
1339 (Array::F64(l), Number::Float(Float::F32(r))) => Array::F64((l.deref() + r).into()),
1340 (Array::F64(l), Number::Float(Float::F64(r))) => Array::F64((l.deref() + r).into()),
1341
1342 (Array::C32(l), Number::Complex(Complex::C32(r))) => Array::C32((l.deref() + r).into()),
1343 (Array::C64(l), Number::Complex(Complex::C64(r))) => Array::C64((l.deref() + r).into()),
1344
1345 (Array::I16(l), Number::Int(Int::I16(r))) => Array::I16((l.deref() + r).into()),
1346 (Array::I32(l), Number::Int(Int::I32(r))) => Array::I32((l.deref() + r).into()),
1347 (Array::I64(l), Number::Int(Int::I64(r))) => Array::I64((l.deref() + r).into()),
1348
1349 (Array::U8(l), Number::UInt(UInt::U8(r))) => Array::U8((l.deref() + r).into()),
1350 (Array::U16(l), Number::UInt(UInt::U16(r))) => Array::U16((l.deref() + r).into()),
1351 (Array::U32(l), Number::UInt(UInt::U32(r))) => Array::U32((l.deref() + r).into()),
1352 (Array::U64(l), Number::UInt(UInt::U64(r))) => Array::U64((l.deref() + r).into()),
1353
1354 (l, r) => match (l.dtype(), r.class()) {
1355 (l_dtype, r_dtype) if l_dtype > r_dtype => l + r.into_type(l_dtype),
1356 (l_dtype, r_dtype) if l_dtype < r_dtype => &l.cast_into(r_dtype) + r,
1357 (l, r) => unreachable!("add {}, {}", l, r),
1358 },
1359 }
1360 }
1361}
1362
1363impl AddAssign<&Array> for Array {
1364 fn add_assign(&mut self, other: &Array) {
1365 let sum = &*self + other;
1366 *self = sum;
1367 }
1368}
1369
1370impl AddAssign<Number> for Array {
1371 fn add_assign(&mut self, rhs: Number) {
1372 *self = &*self + rhs;
1373 }
1374}
1375
1376impl Sub for &Array {
1377 type Output = Array;
1378
1379 fn sub(self, other: &Array) -> Self::Output {
1380 use Array::*;
1381 match (self, other) {
1382 (Bool(l), Bool(r)) => Bool(l - r),
1383 (C32(l), C32(r)) => C32(l - r),
1384 (C64(l), C64(r)) => C64(l - r),
1385 (F32(l), F32(r)) => F32(l - r),
1386 (F64(l), F64(r)) => F64(l - r),
1387 (I16(l), I16(r)) => I16(l - r),
1388 (I32(l), I32(r)) => I32(l - r),
1389 (I64(l), I64(r)) => I64(l - r),
1390 (U8(l), U8(r)) => U8(l - r),
1391 (U16(l), U16(r)) => U16(l - r),
1392 (U32(l), U32(r)) => U32(l - r),
1393 (U64(l), U64(r)) => U64(l - r),
1394 (l, r) => match (l.dtype(), r.dtype()) {
1395 (l_dtype, r_dtype) if l_dtype > r_dtype => l - &r.cast_into(l_dtype),
1396 (l_dtype, r_dtype) if l_dtype < r_dtype => &l.cast_into(r_dtype) - r,
1397 (l, r) => unreachable!("subtract {}, {}", l, r),
1398 },
1399 }
1400 }
1401}
1402
1403impl Sub<Number> for &Array {
1404 type Output = Array;
1405
1406 fn sub(self, rhs: Number) -> Self::Output {
1407 use number_general::Complex;
1408 match (self, rhs) {
1409 (Array::Bool(l), Number::Bool(r)) => Array::Bool((l.deref() - bool::from(r)).into()),
1410
1411 (Array::F32(l), Number::Float(Float::F32(r))) => Array::F32((l.deref() - r).into()),
1412 (Array::F64(l), Number::Float(Float::F64(r))) => Array::F64((l.deref() - r).into()),
1413
1414 (Array::C32(l), Number::Complex(Complex::C32(r))) => Array::C32((l.deref() - r).into()),
1415 (Array::C64(l), Number::Complex(Complex::C64(r))) => Array::C64((l.deref() - r).into()),
1416
1417 (Array::I16(l), Number::Int(Int::I16(r))) => Array::I16((l.deref() - r).into()),
1418 (Array::I32(l), Number::Int(Int::I32(r))) => Array::I32((l.deref() - r).into()),
1419 (Array::I64(l), Number::Int(Int::I64(r))) => Array::I64((l.deref() - r).into()),
1420
1421 (Array::U8(l), Number::UInt(UInt::U8(r))) => Array::U8((l.deref() - r).into()),
1422 (Array::U16(l), Number::UInt(UInt::U16(r))) => Array::U16((l.deref() - r).into()),
1423 (Array::U32(l), Number::UInt(UInt::U32(r))) => Array::U32((l.deref() - r).into()),
1424 (Array::U64(l), Number::UInt(UInt::U64(r))) => Array::U64((l.deref() - r).into()),
1425
1426 (l, r) => match (l.dtype(), r.class()) {
1427 (l_dtype, r_dtype) if l_dtype > r_dtype => l - r.into_type(l_dtype),
1428 (l_dtype, r_dtype) if l_dtype < r_dtype => &l.cast_into(r_dtype) - r,
1429 (l, r) => unreachable!("subtract {}, {}", l, r),
1430 },
1431 }
1432 }
1433}
1434
1435impl SubAssign<&Array> for Array {
1436 fn sub_assign(&mut self, other: &Array) {
1437 let diff = &*self - other;
1438 *self = diff;
1439 }
1440}
1441
1442impl SubAssign<Number> for Array {
1443 fn sub_assign(&mut self, rhs: Number) {
1444 *self = &*self - rhs;
1445 }
1446}
1447
1448impl Mul for &Array {
1449 type Output = Array;
1450
1451 fn mul(self, other: &Array) -> Self::Output {
1452 use Array::*;
1453 match (self, other) {
1454 (Bool(l), Bool(r)) => Bool(l * r),
1455 (C32(l), C32(r)) => C32(l * r),
1456 (C64(l), C64(r)) => C64(l * r),
1457 (F32(l), F32(r)) => F32(l * r),
1458 (F64(l), F64(r)) => F64(l * r),
1459 (I16(l), I16(r)) => I16(l * r),
1460 (I32(l), I32(r)) => I32(l * r),
1461 (I64(l), I64(r)) => I64(l * r),
1462 (U8(l), U8(r)) => U8(l * r),
1463 (U16(l), U16(r)) => U16(l * r),
1464 (U32(l), U32(r)) => U32(l * r),
1465 (U64(l), U64(r)) => U64(l * r),
1466 (l, r) => match (l.dtype(), r.dtype()) {
1467 (l_dtype, r_dtype) if l_dtype > r_dtype => l * &r.cast_into(l_dtype),
1468 (l_dtype, r_dtype) if l_dtype < r_dtype => &l.cast_into(r_dtype) * r,
1469 (l, r) => unreachable!("multiply {}, {}", l, r),
1470 },
1471 }
1472 }
1473}
1474
1475impl Mul<Number> for &Array {
1476 type Output = Array;
1477
1478 fn mul(self, rhs: Number) -> Self::Output {
1479 use number_general::Complex;
1480 match (self, rhs) {
1481 (Array::Bool(l), Number::Bool(r)) => Array::Bool((l.deref() * bool::from(r)).into()),
1482
1483 (Array::F32(l), Number::Float(Float::F32(r))) => Array::F32((l.deref() * r).into()),
1484 (Array::F64(l), Number::Float(Float::F64(r))) => Array::F64((l.deref() * r).into()),
1485
1486 (Array::C32(l), Number::Complex(Complex::C32(r))) => Array::C32((l.deref() * r).into()),
1487 (Array::C64(l), Number::Complex(Complex::C64(r))) => Array::C64((l.deref() * r).into()),
1488
1489 (Array::I16(l), Number::Int(Int::I16(r))) => Array::I16((l.deref() * r).into()),
1490 (Array::I32(l), Number::Int(Int::I32(r))) => Array::I32((l.deref() * r).into()),
1491 (Array::I64(l), Number::Int(Int::I64(r))) => Array::I64((l.deref() * r).into()),
1492
1493 (Array::U8(l), Number::UInt(UInt::U8(r))) => Array::U8((l.deref() * r).into()),
1494 (Array::U16(l), Number::UInt(UInt::U16(r))) => Array::U16((l.deref() * r).into()),
1495 (Array::U32(l), Number::UInt(UInt::U32(r))) => Array::U32((l.deref() * r).into()),
1496 (Array::U64(l), Number::UInt(UInt::U64(r))) => Array::U64((l.deref() * r).into()),
1497
1498 (l, r) => match (l.dtype(), r.class()) {
1499 (l_dtype, r_dtype) if l_dtype > r_dtype => l * r.into_type(l_dtype),
1500 (l_dtype, r_dtype) if l_dtype < r_dtype => &l.cast_into(r_dtype) * r,
1501 (l, r) => unreachable!("subtract {}, {}", l, r),
1502 },
1503 }
1504 }
1505}
1506
1507impl MulAssign<&Array> for Array {
1508 fn mul_assign(&mut self, other: &Array) {
1509 let product = &*self * other;
1510 *self = product;
1511 }
1512}
1513
1514impl MulAssign<Number> for Array {
1515 fn mul_assign(&mut self, rhs: Number) {
1516 *self = &*self * rhs;
1517 }
1518}
1519
1520impl Div for &Array {
1521 type Output = Array;
1522
1523 fn div(self, other: &Array) -> Self::Output {
1524 use Array::*;
1525 match (self, other) {
1526 (Bool(l), Bool(r)) => Bool(l / r),
1527 (C32(l), C32(r)) => C32(l / r),
1528 (C64(l), C64(r)) => C64(l / r),
1529 (F32(l), F32(r)) => F32(l / r),
1530 (F64(l), F64(r)) => F64(l / r),
1531 (I16(l), I16(r)) => I16(l / r),
1532 (I32(l), I32(r)) => I32(l / r),
1533 (I64(l), I64(r)) => I64(l / r),
1534 (U8(l), U8(r)) => U8(l / r),
1535 (U16(l), U16(r)) => U16(l / r),
1536 (U32(l), U32(r)) => U32(l / r),
1537 (U64(l), U64(r)) => U64(l / r),
1538 (l, r) => match (l.dtype(), r.dtype()) {
1539 (l_dtype, r_dtype) if l_dtype > r_dtype => l / &r.cast_into(l_dtype),
1540 (l_dtype, r_dtype) if l_dtype < r_dtype => &l.cast_into(r_dtype) / r,
1541 (l, r) => unreachable!("divide {}, {}", l, r),
1542 },
1543 }
1544 }
1545}
1546
1547impl Div<Number> for &Array {
1548 type Output = Array;
1549
1550 fn div(self, rhs: Number) -> Self::Output {
1551 use number_general::Complex;
1552 match (self, rhs) {
1553 (Array::Bool(l), Number::Bool(r)) => Array::Bool((l.deref() / bool::from(r)).into()),
1554
1555 (Array::F32(l), Number::Float(Float::F32(r))) => Array::F32((l.deref() / r).into()),
1556 (Array::F64(l), Number::Float(Float::F64(r))) => Array::F64((l.deref() / r).into()),
1557
1558 (Array::C32(l), Number::Complex(Complex::C32(r))) => Array::C32((l.deref() / r).into()),
1559 (Array::C64(l), Number::Complex(Complex::C64(r))) => Array::C64((l.deref() / r).into()),
1560
1561 (Array::I16(l), Number::Int(Int::I16(r))) => Array::I16((l.deref() / r).into()),
1562 (Array::I32(l), Number::Int(Int::I32(r))) => Array::I32((l.deref() / r).into()),
1563 (Array::I64(l), Number::Int(Int::I64(r))) => Array::I64((l.deref() / r).into()),
1564
1565 (Array::U8(l), Number::UInt(UInt::U8(r))) => Array::U8((l.deref() / r).into()),
1566 (Array::U16(l), Number::UInt(UInt::U16(r))) => Array::U16((l.deref() / r).into()),
1567 (Array::U32(l), Number::UInt(UInt::U32(r))) => Array::U32((l.deref() / r).into()),
1568 (Array::U64(l), Number::UInt(UInt::U64(r))) => Array::U64((l.deref() / r).into()),
1569
1570 (l, r) => match (l.dtype(), r.class()) {
1571 (l_dtype, r_dtype) if l_dtype > r_dtype => l / r.into_type(l_dtype),
1572 (l_dtype, r_dtype) if l_dtype < r_dtype => &l.cast_into(r_dtype) / r,
1573 (l, r) => unreachable!("subtract {}, {}", l, r),
1574 },
1575 }
1576 }
1577}
1578
1579impl DivAssign<&Array> for Array {
1580 fn div_assign(&mut self, other: &Array) {
1581 let div = &*self / other;
1582 *self = div;
1583 }
1584}
1585
1586impl DivAssign<Number> for Array {
1587 fn div_assign(&mut self, rhs: Number) {
1588 *self = &*self / rhs;
1589 }
1590}
1591
1592impl<T: af::HasAfEnum> CastFrom<Array> for ArrayExt<T> {
1593 fn cast_from(array: Array) -> ArrayExt<T> {
1594 use Array::*;
1595 match array {
1596 Bool(b) => b.type_cast(),
1597 C32(c) => c.type_cast(),
1598 C64(c) => c.type_cast(),
1599 F32(f) => f.type_cast(),
1600 F64(f) => f.type_cast(),
1601 I16(i) => i.type_cast(),
1602 I32(i) => i.type_cast(),
1603 I64(i) => i.type_cast(),
1604 U8(u) => u.type_cast(),
1605 U16(u) => u.type_cast(),
1606 U32(u) => u.type_cast(),
1607 U64(u) => u.type_cast(),
1608 }
1609 }
1610}
1611
// Conversions between `Array` and each variant's concrete `ArrayExt<T>`, generated by
// `safecast::as_type!`. NOTE(review): presumably these provide the `From<ArrayExt<T>> for
// Array` impls that the `.into()` calls elsewhere in this file rely on — confirm against
// the safecast documentation.
as_type!(Array, Bool, ArrayExt<bool>);
as_type!(Array, C32, ArrayExt<Complex<f32>>);
as_type!(Array, C64, ArrayExt<Complex<f64>>);
as_type!(Array, F32, ArrayExt<f32>);
as_type!(Array, F64, ArrayExt<f64>);
as_type!(Array, I16, ArrayExt<i16>);
as_type!(Array, I32, ArrayExt<i32>);
as_type!(Array, I64, ArrayExt<i64>);
as_type!(Array, U8, ArrayExt<u8>);
as_type!(Array, U16, ArrayExt<u16>);
as_type!(Array, U32, ArrayExt<u32>);
as_type!(Array, U64, ArrayExt<u64>);
1624
1625impl<T: af::HasAfEnum> From<Vec<T>> for Array
1626where
1627 Array: From<ArrayExt<T>>,
1628{
1629 fn from(values: Vec<T>) -> Self {
1630 ArrayExt::from(values.as_slice()).into()
1631 }
1632}
1633
1634impl<T: af::HasAfEnum> From<&[T]> for Array
1635where
1636 Array: From<ArrayExt<T>>,
1637{
1638 fn from(values: &[T]) -> Self {
1639 ArrayExt::from(values).into()
1640 }
1641}
1642
1643impl<T: af::HasAfEnum> FromIterator<T> for Array
1644where
1645 Array: From<ArrayExt<T>>,
1646{
1647 fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
1648 ArrayExt::from_iter(iter).into()
1649 }
1650}
1651
1652impl From<Vec<Number>> for Array {
1653 fn from(elements: Vec<Number>) -> Self {
1654 use {ComplexType as CT, FloatType as FT, IntType as IT, NumberType as NT, UIntType as UT};
1655
1656 let dtype = elements.iter().map(|n| n.class()).fold(NT::Bool, Ord::max);
1657
1658 let array = match dtype {
1659 NT::Bool => Self::Bool(array_from(elements)),
1660 NT::Complex(ct) => match ct {
1661 CT::C32 => Self::C32(array_from(elements)),
1662 _ => Self::C64(array_from(elements)),
1663 },
1664 NT::Float(ft) => match ft {
1665 FT::F32 => Self::F32(array_from(elements)),
1666 _ => Self::F64(array_from(elements)),
1667 },
1668 NT::Int(it) => match it {
1669 IT::I8 => Self::I16(array_from(elements)),
1670 IT::I16 => Self::I16(array_from(elements)),
1671 IT::I32 => Self::I32(array_from(elements)),
1672 _ => Self::I64(array_from(elements)),
1673 },
1674 NT::UInt(ut) => match ut {
1675 UT::U8 => Self::U8(array_from(elements)),
1676 UT::U16 => Self::U16(array_from(elements)),
1677 UT::U32 => Self::U32(array_from(elements)),
1678 _ => Self::U64(array_from(elements)),
1679 },
1680 NT::Number => Self::F64(array_from(elements)),
1681 };
1682
1683 array
1684 }
1685}
1686
1687impl<'de> Deserialize<'de> for Array {
1688 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> std::result::Result<Self, D::Error> {
1689 Vec::<Number>::deserialize(deserializer).map(Self::from)
1690 }
1691}
1692
1693fn array_from<T: af::HasAfEnum + CastFrom<Number>>(elements: Vec<Number>) -> ArrayExt<T> {
1694 elements
1695 .into_iter()
1696 .map(|n| n.cast_into())
1697 .collect::<Vec<T>>()
1698 .as_slice()
1699 .into()
1700}
1701
1702impl Serialize for Array {
1703 fn serialize<S: Serializer>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> {
1704 self.to_vec().serialize(serializer)
1705 }
1706}
1707
1708#[async_trait]
1709impl de::FromStream for Array {
1710 type Context = ();
1711
1712 async fn from_stream<D: de::Decoder>(
1713 _: (),
1714 decoder: &mut D,
1715 ) -> std::result::Result<Self, D::Error> {
1716 decoder.decode_seq(ArrayVisitor).await
1717 }
1718}
1719
1720impl<'en> en::ToStream<'en> for Array {
1721 fn to_stream<E: en::Encoder<'en>>(
1722 &'en self,
1723 encoder: E,
1724 ) -> std::result::Result<E::Ok, E::Error> {
1725 use en::IntoStream;
1726
1727 match self {
1728 Self::Bool(array) => (DType::Bool, array).into_stream(encoder),
1729 Self::C32(array) => (DType::C32, array.re(), array.im()).into_stream(encoder),
1730 Self::C64(array) => (DType::C64, array.re(), array.im()).into_stream(encoder),
1731 Self::F32(array) => (DType::F32, array).into_stream(encoder),
1732 Self::F64(array) => (DType::F64, array).into_stream(encoder),
1733 Self::I16(array) => (DType::I16, array).into_stream(encoder),
1734 Self::I32(array) => (DType::I32, array).into_stream(encoder),
1735 Self::I64(array) => (DType::I64, array).into_stream(encoder),
1736 Self::U8(array) => (DType::U8, array).into_stream(encoder),
1737 Self::U16(array) => (DType::U16, array).into_stream(encoder),
1738 Self::U32(array) => (DType::U32, array).into_stream(encoder),
1739 Self::U64(array) => (DType::U64, array).into_stream(encoder),
1740 }
1741 }
1742}
1743
1744impl<'en> en::IntoStream<'en> for Array {
1745 fn into_stream<E: en::Encoder<'en>>(self, encoder: E) -> std::result::Result<E::Ok, E::Error> {
1746 match self {
1747 Self::Bool(array) => (DType::Bool, array).into_stream(encoder),
1748 Self::C32(array) => (DType::C32, array.re(), array.im()).into_stream(encoder),
1749 Self::C64(array) => (DType::C64, array.re(), array.im()).into_stream(encoder),
1750 Self::F32(array) => (DType::F32, array).into_stream(encoder),
1751 Self::F64(array) => (DType::F64, array).into_stream(encoder),
1752 Self::I16(array) => (DType::I16, array).into_stream(encoder),
1753 Self::I32(array) => (DType::I32, array).into_stream(encoder),
1754 Self::I64(array) => (DType::I64, array).into_stream(encoder),
1755 Self::U8(array) => (DType::U8, array).into_stream(encoder),
1756 Self::U16(array) => (DType::U16, array).into_stream(encoder),
1757 Self::U32(array) => (DType::U32, array).into_stream(encoder),
1758 Self::U64(array) => (DType::U64, array).into_stream(encoder),
1759 }
1760 }
1761}
1762
1763impl fmt::Debug for Array {
1764 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1765 match self {
1766 Self::Bool(array) => fmt::Debug::fmt(array, f),
1767 Self::C32(array) => fmt::Debug::fmt(array, f),
1768 Self::C64(array) => fmt::Debug::fmt(array, f),
1769 Self::F32(array) => fmt::Debug::fmt(array, f),
1770 Self::F64(array) => fmt::Debug::fmt(array, f),
1771 Self::I16(array) => fmt::Debug::fmt(array, f),
1772 Self::I32(array) => fmt::Debug::fmt(array, f),
1773 Self::I64(array) => fmt::Debug::fmt(array, f),
1774 Self::U8(array) => fmt::Debug::fmt(array, f),
1775 Self::U16(array) => fmt::Debug::fmt(array, f),
1776 Self::U32(array) => fmt::Debug::fmt(array, f),
1777 Self::U64(array) => fmt::Debug::fmt(array, f),
1778 }
1779 }
1780}
1781
1782impl fmt::Display for Array {
1783 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1784 match self {
1785 Self::Bool(array) => fmt::Display::fmt(array, f),
1786 Self::C32(array) => fmt::Display::fmt(array, f),
1787 Self::C64(array) => fmt::Display::fmt(array, f),
1788 Self::F32(array) => fmt::Display::fmt(array, f),
1789 Self::F64(array) => fmt::Display::fmt(array, f),
1790 Self::I16(array) => fmt::Display::fmt(array, f),
1791 Self::I32(array) => fmt::Display::fmt(array, f),
1792 Self::I64(array) => fmt::Display::fmt(array, f),
1793 Self::U8(array) => fmt::Display::fmt(array, f),
1794 Self::U16(array) => fmt::Display::fmt(array, f),
1795 Self::U32(array) => fmt::Display::fmt(array, f),
1796 Self::U64(array) => fmt::Display::fmt(array, f),
1797 }
1798 }
1799}
1800
/// A `destream` visitor that decodes an [`Array`] from a sequence made of a `DType` tag
/// followed by the array data (for complex types: a real array, then an imaginary array).
struct ArrayVisitor;
1802
1803impl ArrayVisitor {
1804 async fn visit_array<A: de::SeqAccess, T: af::HasAfEnum>(
1805 seq: &mut A,
1806 ) -> std::result::Result<ArrayExt<T>, A::Error>
1807 where
1808 ArrayExt<T>: de::FromStream<Context = ()>,
1809 {
1810 seq.next_element(())
1811 .await?
1812 .ok_or_else(|| de::Error::custom("missing array"))
1813 }
1814}
1815
1816#[async_trait]
1817impl de::Visitor for ArrayVisitor {
1818 type Value = Array;
1819
1820 fn expecting() -> &'static str {
1821 "a numeric array"
1822 }
1823
1824 async fn visit_seq<A: de::SeqAccess>(
1825 self,
1826 mut seq: A,
1827 ) -> std::result::Result<Self::Value, A::Error> {
1828 let dtype = seq
1829 .next_element::<DType>(())
1830 .await?
1831 .ok_or_else(|| de::Error::custom("missing array data type"))?;
1832
1833 match dtype {
1834 DType::Bool => Self::visit_array(&mut seq).map_ok(Array::Bool).await,
1835 DType::C32 => {
1836 let re = Self::visit_array(&mut seq).await?;
1837 let im = Self::visit_array(&mut seq).await?;
1838 Ok(Array::C32(ArrayExt::from((re, im))))
1839 }
1840 DType::C64 => {
1841 let re = Self::visit_array(&mut seq).await?;
1842 let im = Self::visit_array(&mut seq).await?;
1843 Ok(Array::C64(ArrayExt::from((re, im))))
1844 }
1845 DType::F32 => Self::visit_array(&mut seq).map_ok(Array::F32).await,
1846 DType::F64 => Self::visit_array(&mut seq).map_ok(Array::F64).await,
1847 DType::I16 => Self::visit_array(&mut seq).map_ok(Array::I16).await,
1848 DType::I32 => Self::visit_array(&mut seq).map_ok(Array::I32).await,
1849 DType::I64 => Self::visit_array(&mut seq).map_ok(Array::I64).await,
1850 DType::U8 => Self::visit_array(&mut seq).map_ok(Array::U8).await,
1851 DType::U16 => Self::visit_array(&mut seq).map_ok(Array::U16).await,
1852 DType::U32 => Self::visit_array(&mut seq).map_ok(Array::U32).await,
1853 DType::U64 => Self::visit_array(&mut seq).map_ok(Array::U64).await,
1854 }
1855 }
1856}
1857
/// The element data type of an [`Array`]. When an `Array` is streamed with `destream`,
/// this tag is written ahead of the array data and read back by `ArrayVisitor`.
///
/// It is decoded from its `u8` discriminant via the derived `FromPrimitive`, so reordering
/// or inserting variants changes the wire format.
#[derive(Clone, Copy, Eq, PartialEq, num_derive::FromPrimitive, num_derive::ToPrimitive)]
enum DType {
    Bool,
    C32,
    C64,
    F32,
    F64,
    I16,
    I32,
    I64,
    U8,
    U16,
    U32,
    U64,
}
1873
1874#[async_trait]
1875impl de::FromStream for DType {
1876 type Context = ();
1877
1878 async fn from_stream<D: de::Decoder>(
1879 cxt: (),
1880 decoder: &mut D,
1881 ) -> std::result::Result<Self, D::Error> {
1882 let dtype = u8::from_stream(cxt, decoder).await?;
1883 Self::from_u8(dtype).ok_or_else(|| de::Error::invalid_value(dtype, "an array data type"))
1884 }
1885}
1886
1887impl<'en> en::IntoStream<'en> for DType {
1888 fn into_stream<E: en::Encoder<'en>>(self, encoder: E) -> std::result::Result<E::Ok, E::Error> {
1889 self.to_u8().into_stream(encoder)
1890 }
1891}
1892
1893pub(crate) fn reduce_block<T, B, R>(block: &ArrayExt<T>, stride: u64, reduce: &mut R) -> ArrayExt<B>
1894where
1895 T: af::HasAfEnum,
1896 B: af::HasAfEnum,
1897 R: FnMut(af::Array<T>) -> ArrayExt<B>,
1898{
1899 assert_eq!(block.len() as u64 % stride, 0);
1900 let shape = af::Dim4::new(&[stride, block.len() as u64 / stride, 1, 1]);
1901 let block = af::moddims(&block, shape);
1902 let reduced = reduce(block.into());
1903 let shape = af::Dim4::new(&[reduced.len() as u64, 1, 1, 1]);
1904 af::moddims(&reduced, shape).into()
1905}
1906
1907#[cfg(test)]
1908mod tests {
1909 use super::*;
1910
1911 #[test]
1912 fn test_get_value() {
1913 assert_eq!(Array::from(&[1, 2, 3][..]).get_value(1), Number::from(2));
1914 }
1915
1916 #[test]
1917 fn test_get() {
1918 let arr = Array::from(vec![1, 2, 3].as_slice());
1919 let actual = arr.get(&(&[1, 2][..]).into());
1920 let expected = Array::from(&[2, 3][..]);
1921 assert_eq!(actual, expected)
1922 }
1923
1924 #[test]
1925 fn test_set() {
1926 let mut actual = Array::from(&[1, 2, 3][..]);
1927 actual
1928 .set(&(&[1, 2][..]).into(), &Array::from(&[4, 5][..]))
1929 .unwrap();
1930
1931 let expected = Array::from(&[1, 4, 5][..]);
1932 assert_eq!(actual, expected)
1933 }
1934
1935 #[test]
1936 fn test_add() {
1937 let a: Array = [1, 2, 3][..].into();
1938 let b: Array = [1][..].into();
1939 assert_eq!(&a + &b, [2, 3, 4][..].into());
1940
1941 let b: Array = [3, 2, 1][..].into();
1942 assert_eq!(&a + &b, [4, 4, 4][..].into());
1943
1944 assert_eq!(&b + Number::from(1), [4, 3, 2][..].into());
1945 }
1946
1947 #[test]
1948 fn test_add_float() {
1949 let a: Array = [1, 2, 3][..].into();
1950 let b: Array = [2.0][..].into();
1951 assert_eq!(&a + &b, [3.0, 4.0, 5.0][..].into());
1952
1953 let b: Array = [-1., -4., 4.][..].into();
1954 assert_eq!(&a + &b, [0., -2., 7.][..].into());
1955
1956 assert_eq!(&b + Number::from(3), [2, -1, 7][..].into());
1957 }
1958
1959 #[test]
1960 fn test_gte() {
1961 let a: Array = [0, 1, 2][..].into();
1962 let b: Array = [1][..].into();
1963 assert_eq!(a.gte(&b), [false, true, true][..].into());
1964 assert_eq!(a.gte_const(Number::from(1)), [false, true, true][..].into());
1965 }
1966
1967 #[test]
1968 fn test_sub() {
1969 let a: Array = [1, 2, 3][..].into();
1970 let b: Array = [1][..].into();
1971 assert_eq!(&a - &b, [0, 1, 2][..].into());
1972
1973 let b: Array = [3, 2, 1][..].into();
1974 assert_eq!(&a - &b, [-2, 0, 2][..].into());
1975 }
1976
1977 #[test]
1978 fn test_sub_float() {
1979 let a: Array = [1, 2, 3][..].into();
1980 let b: Array = [2.0][..].into();
1981 assert_eq!(&a - &b, [-1.0, 0., 1.0][..].into());
1982
1983 let b: Array = [-1., -4., 4.][..].into();
1984 assert_eq!(&a - &b, [2., 6., -1.][..].into());
1985 }
1986
1987 #[test]
1988 fn test_mul() {
1989 let a: Array = [1, 2, 3][..].into();
1990 let b: Array = [2][..].into();
1991 assert_eq!(&a * &b, [2, 4, 6][..].into());
1992
1993 let b: Array = [5, 4, 3][..].into();
1994 assert_eq!(&a * &b, [5, 8, 9][..].into());
1995 }
1996
1997 #[test]
1998 fn test_mul_const() {
1999 let a: Array = [1, 2, 3][..].into();
2000 let b: Number = 2f32.into();
2001 assert_eq!(&a * b, [2.0, 4.0, 6.0][..].into());
2002 }
2003
2004 #[test]
2005 fn test_mul_float() {
2006 let a: Array = [1.0f32, 2.0f32, 3.0f32][..].into();
2007 let b: Array = [2.0f32][..].into();
2008 assert_eq!(&a * &b, [2.0, 4.0, 6.0][..].into());
2009
2010 let b: Array = [-1., -4., 4.][..].into();
2011 assert_eq!(&a * &b, [-1., -8., 12.][..].into());
2012 }
2013
2014 #[test]
2015 fn test_div() {
2016 let a: Array = [1, 2, 3][..].into();
2017 let b: Array = [2.0][..].into();
2018 assert_eq!(&a / &b, [0.5, 1.0, 1.5][..].into());
2019
2020 let b: Array = [-1., -4., 4.][..].into();
2021 assert_eq!(&a / &b, [-1., -0.5, 0.75][..].into());
2022 }
2023
2024 #[test]
2025 fn test_pow() {
2026 let a: Array = [1, 2, 3][..].into();
2027 let b: Array = [2][..].into();
2028 assert_eq!(a.pow(&b), [1.0, 4.0, 9.0][..].into());
2029
2030 let a: Array = [1, 2, 3][..].into();
2031 let b: Array = [2.0][..].into();
2032 assert_eq!(a.pow(&b), [1.0, 4.0, 9.0][..].into());
2033
2034 let a: Array = [1.0, 2.0, 3.0][..].into();
2035 let b: Array = [2][..].into();
2036 assert_eq!(a.pow(&b), [1.0, 4.0, 9.0][..].into());
2037 }
2038
2039 #[test]
2040 fn test_min_and_max() {
2041 let a: Array = [3, 1, 4, 2][..].into();
2042 assert_eq!(a.min(), 1.into());
2043 assert_eq!(a.max(), 4.into());
2044 }
2045
2046 #[test]
2047 fn test_sum() {
2048 let a: Array = [1, 2, 3, 4][..].into();
2049 assert_eq!(a.sum(), 10.into());
2050 }
2051
2052 #[test]
2053 fn test_product() {
2054 let a: Array = [1, 2, 3, 4][..].into();
2055 assert_eq!(a.product(), 24.into());
2056 }
2057
2058 #[test]
2059 fn test_argsort() {
2060 let a = Array::random_uniform(FloatType::F32, 10);
2061 let (sorted, indices) = a.argsort(true).expect("argsort");
2062 assert_eq!(sorted, a.get(&indices))
2063 }
2064
2065 #[tokio::test]
2066 async fn test_serialization() {
2067 let expected: Array = [1, 2, 3, 4][..].into();
2068 let serialized = tbon::en::encode(&expected).expect("encode");
2069 let actual = tbon::de::try_decode((), serialized).await.expect("decode");
2070 assert!(expected.eq(&actual).all());
2071 }
2072}