1#![allow(clippy::type_complexity)]
2use std::fmt;
6use std::marker::PhantomData;
7
8use crate::access::*;
9use crate::buffer::BufferInstance;
10use crate::ops::*;
11use crate::platform::PlatformInstance;
12#[cfg(feature = "complex")]
13use crate::Complex;
14use crate::{
15 axes, range_shape, shape, strides_for, ArrayAccess, Axes, AxisRange, BufferConverter, Constant,
16 Convert, Error, Float, Number, Platform, Range, Real, Shape,
17};
18
19pub struct Array<T, A, P> {
20 shape: Shape,
21 access: A,
22 platform: P,
23 dtype: PhantomData<T>,
24}
25
26impl<T, A: Clone, P: Clone> Clone for Array<T, A, P> {
27 fn clone(&self) -> Self {
28 Self {
29 shape: self.shape.clone(),
30 access: self.access.clone(),
31 platform: self.platform.clone(),
32 dtype: self.dtype,
33 }
34 }
35}
36
37impl<T, A, P> Array<T, A, P> {
38 fn apply<O, OT, Op>(self, op: Op) -> Result<Array<OT, AccessOp<O, P>, P>, Error>
39 where
40 P: Copy,
41 Op: Fn(P, A) -> Result<AccessOp<O, P>, Error>,
42 {
43 let access = (op)(self.platform, self.access)?;
44
45 Ok(Array {
46 shape: self.shape,
47 access,
48 platform: self.platform,
49 dtype: PhantomData,
50 })
51 }
52
53 fn reduce_axes<'a, Op>(
54 self,
55 mut axes: Axes,
56 keepdims: bool,
57 op: Op,
58 ) -> Result<Array<T, AccessOp<P::Op, P>, P>, Error>
59 where
60 T: Number,
61 A: Access<T>,
62 P: Transform<A, T> + ReduceAxes<Accessor<'a, T>, T>,
63 Op: Fn(P, Accessor<'a, T>, usize) -> Result<AccessOp<P::Op, P>, Error>,
64 Accessor<'a, T>: From<A> + From<AccessOp<P::Transpose, P>>,
65 {
66 axes.sort();
67 axes.dedup();
68
69 let platform = P::select(self.size());
70 let stride = axes.iter().copied().map(|x| self.shape[x]).product();
71 let shape = reduce_axes(&self.shape, &axes, keepdims)?;
72
73 let access = permute_for_reduce(self.platform, self.access, self.shape, axes)?;
74 let access = (op)(self.platform, access, stride)?;
75
76 Ok(Array {
77 access,
78 shape,
79 platform,
80 dtype: PhantomData,
81 })
82 }
83
84 pub fn access(&self) -> &A {
85 &self.access
86 }
87
88 pub fn into_access(self) -> A {
89 self.access
90 }
91}
92
93impl<T, L, P> Array<T, L, P> {
94 fn apply_dual<O, OT, R, Op>(
95 self,
96 other: Array<T, R, P>,
97 op: Op,
98 ) -> Result<Array<OT, AccessOp<O, P>, P>, Error>
99 where
100 P: Copy,
101 Op: Fn(P, L, R) -> Result<AccessOp<O, P>, Error>,
102 {
103 let access = (op)(self.platform, self.access, other.access)?;
104
105 Ok(Array {
106 shape: self.shape,
107 access,
108 platform: self.platform,
109 dtype: PhantomData,
110 })
111 }
112}
113
114impl<'a, T: Number> Array<T, Accessor<'a, T>, Platform> {
116 pub fn from<A, P>(array: Array<T, A, P>) -> Self
117 where
118 A: Into<Accessor<'a, T>>,
119 Platform: From<P>,
120 {
121 Self {
122 shape: array.shape,
123 access: array.access.into(),
124 platform: array.platform.into(),
125 dtype: array.dtype,
126 }
127 }
128}
129
130impl<T, B, P> Array<T, AccessBuf<B>, P>
131where
132 T: Number,
133 B: BufferInstance<T>,
134 P: PlatformInstance,
135{
136 fn new_inner(platform: P, buffer: B, shape: Shape) -> Result<Self, Error> {
137 if !shape.is_empty() && shape.iter().product::<usize>() == buffer.len() {
138 let access = buffer.into();
139
140 Ok(Self {
141 shape,
142 access,
143 platform,
144 dtype: PhantomData,
145 })
146 } else {
147 Err(Error::bounds(format!(
148 "cannot construct an array with shape {shape:?} from a buffer of size {}",
149 buffer.len(),
150 )))
151 }
152 }
153
154 pub fn convert<'a, FB>(buffer: FB, shape: Shape) -> Result<Self, Error>
155 where
156 FB: Into<BufferConverter<'a, T>>,
157 P: Convert<T, Buffer = B>,
158 {
159 let buffer = buffer.into();
160 let platform = P::select(buffer.len());
161 let buffer = platform.convert(buffer)?;
162 Self::new_inner(platform, buffer, shape)
163 }
164
165 pub fn new(buffer: B, shape: Shape) -> Result<Self, Error> {
166 let platform = P::select(buffer.len());
167 Self::new_inner(platform, buffer, shape)
168 }
169}
170
171impl<T, P> Array<T, AccessBuf<P::Buffer>, P>
172where
173 T: Number,
174 P: Constant<T>,
175{
176 pub fn constant(value: T, shape: Shape) -> Result<Self, Error> {
177 if !shape.is_empty() {
178 let size = shape.iter().product();
179 let platform = P::select(size);
180 let buffer = platform.constant(value, size)?;
181 let access = buffer.into();
182
183 Ok(Self {
184 shape,
185 access,
186 platform,
187 dtype: PhantomData,
188 })
189 } else {
190 Err(Error::bounds(
191 "cannot construct an array with an empty shape".to_string(),
192 ))
193 }
194 }
195}
196
197impl<T, A, P> Array<T, A, P>
199where
200 T: Number,
201 A: Access<T>,
202 P: Convert<T>,
203{
204 pub fn copy(&self) -> Result<Array<T, AccessBuf<P::Buffer>, P>, Error> {
205 let buffer = self.buffer().and_then(|buf| self.platform.convert(buf))?;
206
207 Ok(Array {
208 shape: self.shape.clone(),
209 access: buffer.into(),
210 platform: self.platform,
211 dtype: self.dtype,
212 })
213 }
214}
215
216impl<T, A, P> Array<T, A, P>
218where
219 T: Number,
220 A: Access<T>,
221 P: Transform<A, T>,
222 P: ConstructConcat<AccessOp<<P as Transform<A, T>>::Transpose, P>, T>,
223 P: Transform<
224 AccessOp<<P as ConstructConcat<AccessOp<<P as Transform<A, T>>::Transpose, P>, T>>::Op, P>,
225 T,
226 >,
227{
228 pub fn stack<AS>(arrays: AS, axis: usize) -> Result<Array<T, impl Access<T>, P>, Error>
229 where
230 AS: IntoIterator<Item = Self>,
231 {
232 let arrays = arrays
233 .into_iter()
234 .map(|arr| arr.unsqueeze(axes![axis]))
235 .collect::<Result<Vec<_>, Error>>()?;
236
237 Array::transpose_concat(arrays, axis)
238 }
239
240 pub fn transpose_concat(
241 arrays: Vec<Self>,
242 axis: usize,
243 ) -> Result<Array<T, impl Access<T>, P>, Error> {
244 let shape = if let Some(first) = arrays.first() {
245 let shape = first.shape();
246 if axis < shape.len() {
247 Ok(shape)
248 } else {
249 Err(Error::bounds(format!("{first:?} has no axis {axis}")))
250 }
251 } else {
252 Err(Error::bounds(
253 "cannot concatenate an empty list of arrays".to_string(),
254 ))
255 }?;
256
257 for array in arrays.iter().skip(1) {
258 if array.ndim() == shape.len() {
259 for (x, (dim, a_dim)) in shape.iter().zip(array.shape()).enumerate() {
260 if x != axis && dim != a_dim {
261 return Err(Error::bounds(format!(
262 "cannot concatenate {:?} with {:?} at axis {axis}",
263 shape,
264 array.shape()
265 )));
266 }
267 }
268 } else {
269 return Err(Error::bounds(format!(
270 "cannot concatenate {:?} with {:?}",
271 shape,
272 array.shape()
273 )));
274 }
275 }
276
277 let mut permutation: Axes = (0..shape.len()).collect();
278 permutation.swap(0, axis);
279
280 let arrays = arrays
281 .into_iter()
282 .map(|array| array.transpose(permutation.clone()))
283 .collect::<Result<Vec<Array<T, _, P>>, Error>>()?;
284
285 Array::concat(arrays)?.transpose(permutation)
286 }
287}
288
289impl<T, A, P> Array<T, A, P>
290where
291 T: Number,
292 A: Access<T>,
293 P: ConstructConcat<A, T>,
294{
295 pub fn concat(arrays: Vec<Self>) -> Result<Array<T, AccessOp<P::Op, P>, P>, Error> {
296 let mut array_iter = arrays.iter();
297 let first = array_iter.next();
298
299 if let Some(first) = first {
300 let mut shape = Shape::from_slice(first.shape());
301 for next in array_iter {
302 if next.ndim() != shape.len()
303 || (shape.len() > 1 && shape[1..] != next.shape()[1..])
304 {
305 return Err(Error::bounds(format!(
306 "cannot concatenate shapes {:?} and {:?}",
307 shape,
308 next.shape()
309 )));
310 } else {
311 shape[0] += next.shape()[0];
312 }
313 }
314
315 Self::concat_inner(arrays, shape)
316 } else {
317 Err(Error::bounds(
318 "cannot concatenate an empty list of arrays".into(),
319 ))
320 }
321 }
322
323 fn concat_inner(
324 arrays: Vec<Array<T, A, P>>,
325 shape: Shape,
326 ) -> Result<Array<T, AccessOp<P::Op, P>, P>, Error> {
327 let platform = P::select(shape.iter().product());
328
329 let data = arrays
330 .into_iter()
331 .map(|array| array.into_access())
332 .collect();
333
334 platform.concat(data).map(|access| Array {
335 shape,
336 access,
337 platform,
338 dtype: PhantomData,
339 })
340 }
341}
342
343impl<T: Number, P: PlatformInstance> Array<T, AccessOp<P::Range, P>, P>
344where
345 P: ConstructRange<T>,
346{
347 pub fn range(start: T, stop: T, shape: Shape) -> Result<Self, Error> {
348 let size = shape.iter().product();
349 let platform = P::select(size);
350
351 platform.range(start, stop, size).map(|access| Self {
352 shape,
353 access,
354 platform,
355 dtype: PhantomData,
356 })
357 }
358}
359
360impl<P: PlatformInstance> Array<f32, AccessOp<P::Normal, P>, P>
361where
362 P: Random,
363{
364 pub fn random_normal(size: usize) -> Result<Self, Error> {
365 let platform = P::select(size);
366 let shape = shape![size];
367
368 platform.random_normal(size).map(|access| Self {
369 shape,
370 access,
371 platform,
372 dtype: PhantomData,
373 })
374 }
375}
376
377impl<P: PlatformInstance> Array<f32, AccessOp<P::Uniform, P>, P>
378where
379 P: Random,
380{
381 pub fn random_uniform(size: usize) -> Result<Self, Error> {
382 let platform = P::select(size);
383 let shape = shape![size];
384
385 platform.random_uniform(size).map(|access| Self {
386 shape,
387 access,
388 platform,
389 dtype: PhantomData,
390 })
391 }
392}
393
394impl<T, A, P> Array<T, A, P>
396where
397 T: Number,
398 A: Access<T>,
399 P: PlatformInstance,
400{
401 pub fn as_mut<'a, B>(&'a mut self) -> Array<T, B, P>
402 where
403 A: AccessBorrowMut<'a, T, B>,
404 B: AccessMut<T> + 'a,
405 {
406 Array {
407 shape: Shape::from_slice(&self.shape),
408 access: AccessBorrowMut::borrow_mut(&mut self.access),
409 platform: self.platform,
410 dtype: PhantomData,
411 }
412 }
413
414 pub fn as_ref<'a, B>(&'a self) -> Array<T, B, P>
415 where
416 A: AccessBorrow<'a, T, B>,
417 B: Access<T> + 'a,
418 {
419 Array {
420 shape: Shape::from_slice(&self.shape),
421 access: AccessBorrow::borrow(&self.access),
422 platform: self.platform,
423 dtype: PhantomData,
424 }
425 }
426}
427
428impl<'a, T: Number> ArrayAccess<'a, T> {
431 pub fn unstack(
432 self,
433 axis: usize,
434 ) -> Result<Vec<Array<T, impl Access<T> + 'a, Platform>>, Error> {
435 let dim = self
436 .shape()
437 .get(axis)
438 .copied()
439 .ok_or_else(|| Error::bounds(format!("{self:?} has no axis {axis}")))?;
440
441 let prefix = if axis == 0 {
442 Range::with_capacity(1)
443 } else {
444 self.shape
445 .iter()
446 .take(axis)
447 .copied()
448 .map(|dim| AxisRange::In(0, dim, 1))
449 .collect()
450 };
451
452 (0..dim)
453 .map(|r| {
454 let mut range = prefix.clone();
455 range.push(AxisRange::At(r));
456 range
457 })
458 .map(|r| self.clone().slice(r))
459 .collect()
460 }
461}
462
463pub trait NDArray: Send + Sync {
467 type DType: Number;
469
470 type Platform: PlatformInstance;
472
473 fn ndim(&self) -> usize {
475 self.shape().len()
476 }
477
478 fn size(&self) -> usize {
480 self.shape().iter().product()
481 }
482
483 fn shape(&self) -> &[usize];
485}
486
487impl<T, A, P> NDArray for Array<T, A, P>
488where
489 T: Number,
490 A: Access<T>,
491 P: PlatformInstance,
492{
493 type DType = T;
494 type Platform = P;
495
496 fn shape(&self) -> &[usize] {
497 &self.shape
498 }
499}
500
501pub trait NDArrayAbs: NDArray + Sized {
503 type Output: Access<<Self::DType as Number>::Abs>;
505
506 fn abs(
508 self,
509 ) -> Result<Array<<Self::DType as Number>::Abs, Self::Output, Self::Platform>, Error>;
510}
511
512impl<T, A, P> NDArrayAbs for Array<T, A, P>
513where
514 T: Number,
515 A: Access<T>,
516 P: ElementwiseAbs<A, T>,
517{
518 type Output = AccessOp<P::Op, P>;
519
520 fn abs(self) -> Result<Array<T::Abs, Self::Output, Self::Platform>, Error> {
521 self.apply(|platform, access| platform.abs(access))
522 }
523}
524
525pub trait NDArrayRead: NDArray + fmt::Debug + Sized {
527 fn buffer(&self) -> Result<BufferConverter<'_, Self::DType>, Error>;
529
530 fn into_read(
532 self,
533 ) -> Result<
534 Array<
535 Self::DType,
536 AccessBuf<<Self::Platform as Convert<Self::DType>>::Buffer>,
537 Self::Platform,
538 >,
539 Error,
540 >
541 where
542 Self::Platform: Convert<Self::DType>;
543
544 fn read_value(&self, coord: &[usize]) -> Result<Self::DType, Error>;
546}
547
548impl<T, A, P> NDArrayRead for Array<T, A, P>
549where
550 T: Number,
551 A: Access<T>,
552 P: PlatformInstance,
553{
554 fn buffer(&self) -> Result<BufferConverter<'_, T>, Error> {
555 self.access.read()
556 }
557
558 fn into_read(self) -> Result<Array<Self::DType, AccessBuf<P::Buffer>, Self::Platform>, Error>
559 where
560 P: Convert<T>,
561 {
562 let buffer = self.buffer().and_then(|buf| self.platform.convert(buf))?;
563 debug_assert_eq!(buffer.len(), self.size());
564
565 Ok(Array {
566 shape: self.shape,
567 access: buffer.into(),
568 platform: self.platform,
569 dtype: self.dtype,
570 })
571 }
572
573 fn read_value(&self, coord: &[usize]) -> Result<T, Error> {
574 valid_coord(coord, self.shape())?;
575
576 let strides = strides_for(self.shape(), self.ndim());
577
578 let offset = coord
579 .iter()
580 .zip(strides)
581 .map(|(i, stride)| i * stride)
582 .sum();
583
584 self.access.read_value(offset)
585 }
586}
587
588pub trait NDArrayWrite: NDArray + fmt::Debug + Sized {
590 fn write<O: NDArrayRead<DType = Self::DType>>(&mut self, other: &O) -> Result<(), Error>;
592
593 fn write_value(&mut self, value: Self::DType) -> Result<(), Error>;
595
596 fn write_value_at(&mut self, coord: &[usize], value: Self::DType) -> Result<(), Error>;
598}
599
600impl<T, A, P> NDArrayWrite for Array<T, A, P>
602where
603 T: Number,
604 A: AccessMut<T>,
605 P: PlatformInstance,
606{
607 fn write<O>(&mut self, other: &O) -> Result<(), Error>
608 where
609 O: NDArrayRead<DType = Self::DType>,
610 {
611 same_shape("write", self.shape(), other.shape())?;
612 other.buffer().and_then(|buf| self.access.write(buf))
613 }
614
615 fn write_value(&mut self, value: Self::DType) -> Result<(), Error> {
616 self.access.write_value(value)
617 }
618
619 fn write_value_at(&mut self, coord: &[usize], value: Self::DType) -> Result<(), Error> {
620 valid_coord(coord, self.shape())?;
621
622 let offset = coord
623 .iter()
624 .zip(strides_for(self.shape(), self.ndim()))
625 .map(|(i, stride)| i * stride)
626 .sum();
627
628 self.access.write_value_at(offset, value)
629 }
630}
631
632pub trait NDArrayCast<OT: Number>: NDArray + Sized {
636 type Output: Access<OT>;
637
638 fn cast(self) -> Result<Array<OT, Self::Output, Self::Platform>, Error>;
640}
641
642impl<IT, OT, A, P> NDArrayCast<OT> for Array<IT, A, P>
643where
644 IT: Number,
645 OT: Number,
646 A: Access<IT>,
647 P: ElementwiseCast<A, IT, OT>,
648{
649 type Output = AccessOp<P::Op, P>;
650
651 fn cast(self) -> Result<Array<OT, AccessOp<P::Op, P>, P>, Error> {
652 Ok(Array {
653 shape: self.shape,
654 access: self.platform.cast(self.access)?,
655 platform: self.platform,
656 dtype: PhantomData,
657 })
658 }
659}
660
661pub trait NDArrayReduce<'a>: NDArray + fmt::Debug {
663 type Output: Access<Self::DType> + 'a;
664
665 fn max(
667 self,
668 axes: Axes,
669 keepdims: bool,
670 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>
671 where
672 Self::DType: Real;
673
674 fn min(
676 self,
677 axes: Axes,
678 keepdims: bool,
679 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>
680 where
681 Self::DType: Real;
682
683 fn product(
685 self,
686 axes: Axes,
687 keepdims: bool,
688 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
689
690 fn sum(
692 self,
693 axes: Axes,
694 keepdims: bool,
695 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
696}
697
698impl<'a, T, A, P> NDArrayReduce<'a> for Array<T, A, P>
699where
700 T: Number + 'a,
701 A: Access<T> + 'a,
702 P: Transform<A, T> + ReduceAxes<Accessor<'a, T>, T> + 'a,
703 Accessor<'a, T>: From<A> + From<AccessOp<P::Transpose, P>> + 'a,
704{
705 type Output = AccessOp<P::Op, P>;
706
707 fn max(
708 self,
709 axes: Axes,
710 keepdims: bool,
711 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>
712 where
713 T: Real,
714 {
715 self.reduce_axes(axes, keepdims, |platform, access, stride| {
716 ReduceAxes::max(platform, access, stride)
717 })
718 }
719
720 fn min(
721 self,
722 axes: Axes,
723 keepdims: bool,
724 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>
725 where
726 T: Real,
727 {
728 self.reduce_axes(axes, keepdims, |platform, access, stride| {
729 ReduceAxes::min(platform, access, stride)
730 })
731 }
732
733 fn product(
734 self,
735 axes: Axes,
736 keepdims: bool,
737 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
738 self.reduce_axes(axes, keepdims, |platform, access, stride| {
739 ReduceAxes::product(platform, access, stride)
740 })
741 }
742
743 fn sum(
744 self,
745 axes: Axes,
746 keepdims: bool,
747 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
748 self.reduce_axes(axes, keepdims, |platform, access, stride| {
749 ReduceAxes::sum(platform, access, stride)
750 })
751 }
752}
753
754pub trait NDArrayTransform: NDArray + Sized + fmt::Debug {
756 type Broadcast: Access<Self::DType>;
758
759 type Flip: Access<Self::DType>;
761
762 type Slice: Access<Self::DType>;
764
765 type Transpose: Access<Self::DType>;
767
768 fn broadcast(
770 self,
771 shape: Shape,
772 ) -> Result<Array<Self::DType, Self::Broadcast, Self::Platform>, Error>;
773
774 fn flip(self, axis: usize) -> Result<Array<Self::DType, Self::Flip, Self::Platform>, Error>;
775
776 fn reshape(self, shape: Shape) -> Result<Self, Error>;
778
779 fn slice(self, range: Range) -> Result<Array<Self::DType, Self::Slice, Self::Platform>, Error>;
781
782 fn squeeze(self, axes: Axes) -> Result<Self, Error>;
785
786 fn unsqueeze(self, axes: Axes) -> Result<Self, Error>;
788
789 fn transpose<P: Into<Option<Axes>>>(
792 self,
793 permutation: P,
794 ) -> Result<Array<Self::DType, Self::Transpose, Self::Platform>, Error>;
795}
796
797impl<T, A, P> NDArrayTransform for Array<T, A, P>
798where
799 T: Number,
800 A: Access<T>,
801 P: Transform<A, T>,
802{
803 type Broadcast = AccessOp<P::Broadcast, P>;
804 type Flip = AccessOp<P::Flip, P>;
805 type Slice = AccessOp<P::Slice, P>;
806 type Transpose = AccessOp<P::Transpose, P>;
807
808 fn broadcast(self, shape: Shape) -> Result<Array<T, AccessOp<P::Broadcast, P>, P>, Error> {
809 if !can_broadcast(self.shape(), &shape) {
810 return Err(Error::bounds(format!(
811 "cannot broadcast {self:?} into {shape:?}"
812 )));
813 }
814
815 let platform = P::select(shape.iter().product());
816 let broadcast = Shape::from_slice(&shape);
817 let access = platform.broadcast(self.access, self.shape, broadcast)?;
818
819 Ok(Array {
820 shape,
821 access,
822 platform,
823 dtype: self.dtype,
824 })
825 }
826
827 fn flip(self, axis: usize) -> Result<Array<T, AccessOp<P::Flip, P>, P>, Error> {
828 let platform = self.platform;
829 let access = platform.flip(self.access, self.shape.clone(), axis)?;
830
831 Ok(Array {
832 shape: self.shape,
833 access,
834 platform,
835 dtype: self.dtype,
836 })
837 }
838
839 fn reshape(mut self, shape: Shape) -> Result<Self, Error> {
840 if shape.iter().product::<usize>() == self.size() {
841 self.shape = shape;
842 Ok(self)
843 } else {
844 Err(Error::bounds(format!(
845 "cannot reshape an array with shape {:?} into {shape:?}",
846 self.shape
847 )))
848 }
849 }
850
851 fn slice(self, mut range: Range) -> Result<Array<T, AccessOp<P::Slice, P>, P>, Error> {
852 for (dim, range) in self.shape.iter().zip(&range) {
853 match range {
854 AxisRange::At(i) if i < dim => Ok(()),
855 AxisRange::In(start, stop, _step) if start < dim && stop <= dim => Ok(()),
856 AxisRange::Of(indices) if indices.iter().all(|i| i < dim) => Ok(()),
857 range => Err(Error::bounds(format!(
858 "invalid range {range:?} for dimension {dim}"
859 ))),
860 }?;
861 }
862
863 for dim in self.shape.iter().skip(range.len()).copied() {
864 range.push(AxisRange::In(0, dim, 1));
865 }
866
867 let shape = range_shape(self.shape(), &range);
868 let access = self.platform.slice(self.access, &self.shape, range)?;
869 let platform = P::select(shape.iter().product());
870
871 Ok(Array {
872 shape,
873 access,
874 platform,
875 dtype: self.dtype,
876 })
877 }
878
879 fn squeeze(mut self, mut axes: Axes) -> Result<Self, Error> {
880 axes.sort();
881
882 for x in axes.into_iter().rev() {
883 if x < self.shape.len() {
884 self.shape.remove(x);
885 } else {
886 return Err(Error::bounds(format!("axis out of bounds: {x}")));
887 }
888 }
889
890 Ok(self)
891 }
892
893 fn unsqueeze(mut self, axes: Axes) -> Result<Self, Error> {
894 for x in axes {
895 if x <= self.shape.len() {
896 self.shape.insert(x, 1);
897 } else {
898 return Err(Error::bounds(format!("axis out of bounds: {x}")));
899 }
900 }
901
902 Ok(self)
903 }
904
905 fn transpose<PA: Into<Option<Axes>>>(
906 self,
907 permutation: PA,
908 ) -> Result<Array<T, AccessOp<P::Transpose, P>, P>, Error> {
909 let permutation = if let Some(axes) = permutation.into() {
910 if axes.len() == self.ndim()
911 && axes.iter().copied().all(|x| x < self.ndim())
912 && !(1..axes.len()).any(|i| axes[i..].contains(&axes[i - 1]))
913 {
914 Ok(axes)
915 } else {
916 Err(Error::bounds(format!(
917 "invalid permutation for shape {:?}: {:?}",
918 self.shape, axes
919 )))
920 }
921 } else {
922 Ok((0..self.ndim()).rev().collect())
923 }?;
924
925 let shape = permutation.iter().copied().map(|x| self.shape[x]).collect();
926 let platform = self.platform;
927 let access = platform.transpose(self.access, self.shape, permutation)?;
928
929 Ok(Array {
930 shape,
931 access,
932 platform,
933 dtype: self.dtype,
934 })
935 }
936}
937
938pub trait NDArrayUnary: NDArray + Sized {
940 type Output: Access<Self::DType>;
942
943 fn exp(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
945
946 fn ln(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
948
949 fn round(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>
951 where
952 Self::DType: Real;
953}
954
955impl<T, A, P> NDArrayUnary for Array<T, A, P>
956where
957 T: Float,
958 A: Access<T>,
959 P: ElementwiseUnary<A, T>,
960{
961 type Output = AccessOp<P::Op, P>;
962
963 fn exp(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
964 self.apply(|platform, access| platform.exp(access))
965 }
966
967 fn ln(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>
968 where
969 P: ElementwiseUnary<A, T>,
970 {
971 self.apply(|platform, access| platform.ln(access))
972 }
973
974 fn round(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>
975 where
976 T: Real,
977 {
978 self.apply(|platform, access| platform.round(access))
979 }
980}
981
982pub trait NDArrayUnaryBoolean: NDArray + Sized {
984 type Output: Access<u8>;
986
987 fn not(self) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
989}
990
991impl<T, A, P> NDArrayUnaryBoolean for Array<T, A, P>
992where
993 T: Number,
994 A: Access<T>,
995 P: ElementwiseUnaryBoolean<A, T>,
996{
997 type Output = AccessOp<P::Op, P>;
998
999 fn not(self) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1000 self.apply(|platform, access| platform.not(access))
1001 }
1002}
1003
1004pub trait NDArrayBoolean<O>: NDArray + Sized
1006where
1007 O: NDArray<DType = Self::DType>,
1008{
1009 type Output: Access<u8>;
1010
1011 fn and(self, other: O) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
1013
1014 fn or(self, other: O) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
1016
1017 fn xor(self, other: O) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
1019}
1020
1021impl<T, L, R, P> NDArrayBoolean<Array<T, R, P>> for Array<T, L, P>
1022where
1023 T: Number,
1024 L: Access<T>,
1025 R: Access<T>,
1026 P: ElementwiseBoolean<L, R, T>,
1027{
1028 type Output = AccessOp<P::Op, P>;
1029
1030 fn and(self, other: Array<T, R, P>) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1031 same_shape("and", self.shape(), other.shape())?;
1032 self.apply_dual(other, |platform, left, right| platform.and(left, right))
1033 }
1034
1035 fn or(self, other: Array<T, R, P>) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1036 same_shape("or", self.shape(), other.shape())?;
1037 self.apply_dual(other, |platform, left, right| platform.or(left, right))
1038 }
1039
1040 fn xor(self, other: Array<T, R, P>) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1041 same_shape("xor", self.shape(), other.shape())?;
1042 self.apply_dual(other, |platform, left, right| platform.xor(left, right))
1043 }
1044}
1045
1046pub trait NDArrayBooleanScalar: NDArray + Sized {
1048 type Output: Access<u8>;
1049
1050 fn and_scalar(
1052 self,
1053 other: Self::DType,
1054 ) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
1055
1056 fn or_scalar(
1058 self,
1059 other: Self::DType,
1060 ) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
1061
1062 fn xor_scalar(
1064 self,
1065 other: Self::DType,
1066 ) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
1067}
1068
1069impl<T, A, P> NDArrayBooleanScalar for Array<T, A, P>
1070where
1071 T: Number,
1072 A: Access<T>,
1073 P: ElementwiseBooleanScalar<A, T>,
1074{
1075 type Output = AccessOp<P::Op, P>;
1076
1077 fn and_scalar(
1078 self,
1079 other: Self::DType,
1080 ) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1081 self.apply(|platform, access| platform.and_scalar(access, other))
1082 }
1083
1084 fn or_scalar(
1085 self,
1086 other: Self::DType,
1087 ) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1088 self.apply(|platform, access| platform.or_scalar(access, other))
1089 }
1090
1091 fn xor_scalar(
1092 self,
1093 other: Self::DType,
1094 ) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1095 self.apply(|platform, access| platform.xor_scalar(access, other))
1096 }
1097}
1098
1099pub trait NDArrayCompare<O: NDArray<DType = Self::DType>>: NDArray + Sized {
1101 type Output: Access<u8>;
1102
1103 fn eq(self, other: O) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
1105
1106 fn ge(self, other: O) -> Result<Array<u8, Self::Output, Self::Platform>, Error>
1108 where
1109 Self::DType: Real;
1110
1111 fn gt(self, other: O) -> Result<Array<u8, Self::Output, Self::Platform>, Error>
1113 where
1114 Self::DType: Real;
1115
1116 fn le(self, other: O) -> Result<Array<u8, Self::Output, Self::Platform>, Error>
1118 where
1119 Self::DType: Real;
1120
1121 fn lt(self, other: O) -> Result<Array<u8, Self::Output, Self::Platform>, Error>
1123 where
1124 Self::DType: Real;
1125
1126 fn ne(self, other: O) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
1128}
1129
1130impl<T, L, R, P> NDArrayCompare<Array<T, R, P>> for Array<T, L, P>
1131where
1132 T: Number,
1133 L: Access<T>,
1134 R: Access<T>,
1135 P: ElementwiseCompare<L, R, T>,
1136{
1137 type Output = AccessOp<P::Op, P>;
1138
1139 fn eq(self, other: Array<T, R, P>) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1140 same_shape("compare", self.shape(), other.shape())?;
1141 self.apply_dual(other, |platform, left, right| platform.eq(left, right))
1142 }
1143
1144 fn ge(self, other: Array<T, R, P>) -> Result<Array<u8, Self::Output, Self::Platform>, Error>
1145 where
1146 T: Real,
1147 {
1148 same_shape("compare", self.shape(), other.shape())?;
1149 self.apply_dual(other, |platform, left, right| platform.ge(left, right))
1150 }
1151
1152 fn gt(self, other: Array<T, R, P>) -> Result<Array<u8, Self::Output, Self::Platform>, Error>
1153 where
1154 T: Real,
1155 {
1156 same_shape("compare", self.shape(), other.shape())?;
1157 self.apply_dual(other, |platform, left, right| platform.gt(left, right))
1158 }
1159
1160 fn le(self, other: Array<T, R, P>) -> Result<Array<u8, Self::Output, Self::Platform>, Error>
1161 where
1162 T: Real,
1163 {
1164 same_shape("compare", self.shape(), other.shape())?;
1165 self.apply_dual(other, |platform, left, right| platform.le(left, right))
1166 }
1167
1168 fn lt(self, other: Array<T, R, P>) -> Result<Array<u8, Self::Output, Self::Platform>, Error>
1169 where
1170 T: Real,
1171 {
1172 same_shape("compare", self.shape(), other.shape())?;
1173 self.apply_dual(other, |platform, left, right| platform.lt(left, right))
1174 }
1175
1176 fn ne(self, other: Array<T, R, P>) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1177 same_shape("compare", self.shape(), other.shape())?;
1178 self.apply_dual(other, |platform, left, right| platform.ne(left, right))
1179 }
1180}
1181
1182pub trait NDArrayCompareScalar: NDArray + Sized {
1184 type Output: Access<u8>;
1185
1186 fn eq_scalar(
1188 self,
1189 other: Self::DType,
1190 ) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
1191
1192 fn gt_scalar(
1194 self,
1195 other: Self::DType,
1196 ) -> Result<Array<u8, Self::Output, Self::Platform>, Error>
1197 where
1198 Self::DType: Real;
1199
1200 fn ge_scalar(
1202 self,
1203 other: Self::DType,
1204 ) -> Result<Array<u8, Self::Output, Self::Platform>, Error>
1205 where
1206 Self::DType: Real;
1207
1208 fn lt_scalar(
1210 self,
1211 other: Self::DType,
1212 ) -> Result<Array<u8, Self::Output, Self::Platform>, Error>
1213 where
1214 Self::DType: Real;
1215
1216 fn le_scalar(
1218 self,
1219 other: Self::DType,
1220 ) -> Result<Array<u8, Self::Output, Self::Platform>, Error>
1221 where
1222 Self::DType: Real;
1223
1224 fn ne_scalar(
1226 self,
1227 other: Self::DType,
1228 ) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
1229}
1230
1231impl<T, A, P> NDArrayCompareScalar for Array<T, A, P>
1232where
1233 T: Number,
1234 A: Access<T>,
1235 P: ElementwiseCompareScalar<A, T>,
1236{
1237 type Output = AccessOp<P::Op, P>;
1238
1239 fn eq_scalar(
1240 self,
1241 other: Self::DType,
1242 ) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1243 self.apply(|platform, access| platform.eq_scalar(access, other))
1244 }
1245
1246 fn gt_scalar(self, other: Self::DType) -> Result<Array<u8, Self::Output, Self::Platform>, Error>
1247 where
1248 T: Real,
1249 {
1250 self.apply(|platform, access| platform.gt_scalar(access, other))
1251 }
1252
1253 fn ge_scalar(self, other: Self::DType) -> Result<Array<u8, Self::Output, Self::Platform>, Error>
1254 where
1255 T: Real,
1256 {
1257 self.apply(|platform, access| platform.ge_scalar(access, other))
1258 }
1259
1260 fn lt_scalar(self, other: Self::DType) -> Result<Array<u8, Self::Output, Self::Platform>, Error>
1261 where
1262 T: Real,
1263 {
1264 self.apply(|platform, access| platform.lt_scalar(access, other))
1265 }
1266
1267 fn le_scalar(self, other: Self::DType) -> Result<Array<u8, Self::Output, Self::Platform>, Error>
1268 where
1269 T: Real,
1270 {
1271 self.apply(|platform, access| platform.le_scalar(access, other))
1272 }
1273
1274 fn ne_scalar(
1275 self,
1276 other: Self::DType,
1277 ) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1278 self.apply(|platform, access| platform.ne_scalar(access, other))
1279 }
1280}
1281
/// Operations specific to complex-valued arrays.
#[cfg(feature = "complex")]
pub trait NDArrayComplex: NDArray + Sized
where
    Self::DType: Complex,
{
    /// The accessor type of an array of real values derived from complex elements.
    type Real: Access<<Self::DType as Complex>::Real>;

    /// The accessor type of a complex-valued result.
    type Complex: Access<Self::DType>;

    /// Element-wise `angle`, producing a real-valued array.
    fn angle(
        self,
    ) -> Result<Array<<Self::DType as Complex>::Real, Self::Real, Self::Platform>, Error>;

    /// Element-wise `conj`, producing a complex-valued array.
    fn conj(self) -> Result<Array<Self::DType, Self::Complex, Self::Platform>, Error>;

    /// Element-wise `re`, producing a real-valued array.
    ///
    /// The element type is `<Self::DType as Complex>::Real` to match the `Self::Real`
    /// accessor, as in `angle`.
    fn re(
        self,
    ) -> Result<Array<<Self::DType as Complex>::Real, Self::Real, Self::Platform>, Error>;

    /// Element-wise `im`, producing a real-valued array.
    ///
    /// The element type is `<Self::DType as Complex>::Real` to match the `Self::Real`
    /// accessor, as in `angle`.
    fn im(
        self,
    ) -> Result<Array<<Self::DType as Complex>::Real, Self::Real, Self::Platform>, Error>;
}

#[cfg(feature = "complex")]
impl<T, A, P> NDArrayComplex for Array<T, A, P>
where
    T: Complex,
    A: Access<T>,
    P: complex::ElementwiseUnaryComplex<A, T>,
{
    type Real = AccessOp<P::Real, P>;
    type Complex = AccessOp<P::Complex, P>;

    fn angle(self) -> Result<Array<T::Real, Self::Real, Self::Platform>, Error> {
        self.apply(|platform, access| platform.angle(access))
    }

    fn conj(self) -> Result<Array<Self::DType, Self::Complex, Self::Platform>, Error> {
        self.apply(|platform, access| platform.conj(access))
    }

    fn re(self) -> Result<Array<T::Real, Self::Real, Self::Platform>, Error> {
        self.apply(|platform, access| platform.re(access))
    }

    fn im(self) -> Result<Array<T::Real, Self::Real, Self::Platform>, Error> {
        self.apply(|platform, access| platform.im(access))
    }
}
1332
#[cfg(feature = "complex")]
/// Discrete Fourier transforms of complex-valued arrays.
pub trait NDArrayFourier: NDArray + Sized
where
    Self::DType: Complex,
{
    /// Accessor type of a transformed array, which keeps the input's complex dtype.
    type Output: Access<Self::DType>;

    /// Compute the forward transform.
    ///
    /// The `Array` implementation in this module passes the length of the last axis to the
    /// platform's `fft` kernel. It returns a bounds error for an array with no axes.
    fn fft(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;

    /// Compute the inverse transform.
    ///
    /// The `Array` implementation in this module passes the length of the last axis to the
    /// platform's `ifft` kernel. It returns a bounds error for an array with no axes.
    fn ifft(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
}
1347
#[cfg(feature = "complex")]
impl<A, T, P> NDArrayFourier for Array<num_complex::Complex<T>, A, P>
where
    A: Access<num_complex::Complex<T>>,
    num_complex::Complex<T>: Complex,
    P: complex::Fourier<A, num_complex::Complex<T>>,
{
    type Output = AccessOp<P::Op, P>;

    /// Forward transform. The platform kernel receives the length of the last axis, and an
    /// array without axes is rejected with a bounds error.
    fn fft(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
        let dim = match self.shape.last() {
            Some(&len) => len,
            None => {
                return Err(Error::bounds(
                    "a scalar value has no Fourier transform".into(),
                ))
            }
        };

        self.apply(move |platform, access| platform.fft(access, dim))
    }

    /// Inverse transform. The platform kernel receives the length of the last axis, and an
    /// array without axes is rejected with a bounds error.
    fn ifft(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
        let dim = match self.shape.last() {
            Some(&len) => len,
            None => {
                return Err(Error::bounds(
                    "a scalar value has no Fourier transform".into(),
                ))
            }
        };

        self.apply(move |platform, access| platform.ifft(access, dim))
    }
}
1377
/// Elementwise binary arithmetic between two arrays with the same data type.
///
/// The `Array` implementation in this module requires both operands to have exactly the same
/// shape (no implicit broadcasting). Otherwise it returns a bounds error.
pub trait NDArrayMath<O: NDArray<DType = Self::DType>>: NDArray + Sized {
    /// Accessor type of the result.
    type Output: Access<Self::DType>;

    /// Add `rhs` to this array elementwise.
    fn add(self, rhs: O) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;

    /// Divide this array by `rhs` elementwise.
    fn div(self, rhs: O) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;

    /// Compute the logarithm of each element with respect to the matching element of `base`.
    fn log(self, base: O) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>
    where
        Self::DType: Float;

    /// Multiply this array by `rhs` elementwise.
    fn mul(self, rhs: O) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;

    /// Raise each element to the power of the matching element of `exp`.
    fn pow(self, exp: O) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;

    /// Subtract `rhs` from this array elementwise.
    fn sub(self, rhs: O) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;

    /// Compute the elementwise remainder of this array divided by `rhs`.
    fn rem(self, rhs: O) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>
    where
        Self::DType: Real;
}
1408
1409impl<T, L, R, P> NDArrayMath<Array<T, R, P>> for Array<T, L, P>
1410where
1411 T: Number,
1412 L: Access<T>,
1413 R: Access<T>,
1414 P: ElementwiseDual<L, R, T>,
1415{
1416 type Output = AccessOp<P::Op, P>;
1417
1418 fn add(
1419 self,
1420 rhs: Array<T, R, P>,
1421 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1422 same_shape("add", self.shape(), rhs.shape())?;
1423 self.apply_dual(rhs, |platform, left, right| platform.add(left, right))
1424 }
1425
1426 fn div(
1427 self,
1428 rhs: Array<T, R, P>,
1429 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1430 same_shape("divide", self.shape(), rhs.shape())?;
1431 self.apply_dual(rhs, |platform, left, right| platform.div(left, right))
1432 }
1433
1434 fn log(
1435 self,
1436 base: Array<T, R, P>,
1437 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>
1438 where
1439 T: Float,
1440 {
1441 same_shape("log", self.shape(), base.shape())?;
1442 self.apply_dual(base, |platform, left, right| platform.log(left, right))
1443 }
1444
1445 fn mul(
1446 self,
1447 rhs: Array<T, R, P>,
1448 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1449 same_shape("multiply", self.shape(), rhs.shape())?;
1450 self.apply_dual(rhs, |platform, left, right| platform.mul(left, right))
1451 }
1452
1453 fn pow(
1454 self,
1455 exp: Array<T, R, P>,
1456 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1457 same_shape("exponentiate", self.shape(), exp.shape())?;
1458 self.apply_dual(exp, |platform, left, right| platform.pow(left, right))
1459 }
1460
1461 fn sub(
1462 self,
1463 rhs: Array<T, R, P>,
1464 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1465 same_shape("subtract", self.shape(), rhs.shape())?;
1466 self.apply_dual(rhs, |platform, left, right| platform.sub(left, right))
1467 }
1468
1469 fn rem(
1470 self,
1471 rhs: Array<T, R, P>,
1472 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>
1473 where
1474 T: Real,
1475 {
1476 same_shape("remainder", self.shape(), rhs.shape())?;
1477 self.apply_dual(rhs, |platform, left, right| platform.rem(left, right))
1478 }
1479}
1480
/// Elementwise arithmetic between an array and a single scalar of the same data type.
pub trait NDArrayMathScalar: NDArray + Sized {
    /// Accessor type of the result.
    type Output: Access<Self::DType>;

    /// Add the scalar `rhs` to every element.
    fn add_scalar(
        self,
        rhs: Self::DType,
    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;

    /// Divide every element by the scalar `rhs`.
    ///
    /// The `Array` implementation in this module returns an `unsupported` error when `rhs`
    /// equals zero.
    fn div_scalar(
        self,
        rhs: Self::DType,
    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;

    /// Compute the logarithm of every element with respect to the scalar `base`.
    fn log_scalar(
        self,
        base: Self::DType,
    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>
    where
        Self::DType: Float;

    /// Multiply every element by the scalar `rhs`.
    fn mul_scalar(
        self,
        rhs: Self::DType,
    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;

    /// Raise every element to the scalar power `exp`.
    fn pow_scalar(
        self,
        exp: Self::DType,
    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;

    /// Compute the remainder of every element divided by the scalar `rhs`.
    fn rem_scalar(
        self,
        rhs: Self::DType,
    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>
    where
        Self::DType: Real;

    /// Subtract the scalar `rhs` from every element.
    fn sub_scalar(
        self,
        rhs: Self::DType,
    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
}
1531
1532impl<T, A, P> NDArrayMathScalar for Array<T, A, P>
1533where
1534 T: Number,
1535 A: Access<T>,
1536 P: ElementwiseScalar<A, T>,
1537{
1538 type Output = AccessOp<P::Op, P>;
1539
1540 fn add_scalar(
1541 self,
1542 rhs: Self::DType,
1543 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1544 self.apply(|platform, left| platform.add_scalar(left, rhs))
1545 }
1546
1547 fn div_scalar(
1548 self,
1549 rhs: Self::DType,
1550 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1551 if rhs == T::ZERO {
1552 Err(Error::unsupported(format!(
1553 "cannot divide {self:?} by {rhs}"
1554 )))
1555 } else {
1556 self.apply(|platform, left| platform.div_scalar(left, rhs))
1557 }
1558 }
1559
1560 fn log_scalar(
1561 self,
1562 base: Self::DType,
1563 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>
1564 where
1565 Self::DType: Float,
1566 {
1567 self.apply(|platform, arg| platform.log_scalar(arg, base))
1568 }
1569
1570 fn mul_scalar(
1571 self,
1572 rhs: Self::DType,
1573 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1574 self.apply(|platform, left| platform.mul_scalar(left, rhs))
1575 }
1576
1577 fn pow_scalar(
1578 self,
1579 exp: Self::DType,
1580 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1581 self.apply(|platform, arg| platform.pow_scalar(arg, exp))
1582 }
1583
1584 fn rem_scalar(
1585 self,
1586 rhs: Self::DType,
1587 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>
1588 where
1589 Self::DType: Real,
1590 {
1591 self.apply(|platform, left| platform.rem_scalar(left, rhs))
1592 }
1593
1594 fn sub_scalar(
1595 self,
1596 rhs: Self::DType,
1597 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1598 self.apply(|platform, left| platform.sub_scalar(left, rhs))
1599 }
1600}
1601
/// Elementwise floating-point classification, producing `u8` arrays.
pub trait NDArrayNumeric: NDArray + Sized
where
    Self::DType: Float,
{
    /// Accessor type of the `u8` result.
    type Output: Access<u8>;

    /// Test each element for infinity.
    // clippy's `wrong_self_convention` expects `is_*` methods to take `&self`. Like every
    // other array operation in this module, this one consumes `self`, so the lint is
    // silenced.
    #[allow(clippy::wrong_self_convention)]
    fn is_inf(self) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;

    /// Test each element for NaN.
    // See `is_inf` for why `wrong_self_convention` is allowed here.
    #[allow(clippy::wrong_self_convention)]
    fn is_nan(self) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
}
1619
1620impl<T, A, P> NDArrayNumeric for Array<T, A, P>
1621where
1622 T: Float,
1623 A: Access<T>,
1624 P: ElementwiseNumeric<A, T>,
1625{
1626 type Output = AccessOp<P::Op, P>;
1627
1628 fn is_inf(self) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1629 self.apply(|platform, access| platform.is_inf(access))
1630 }
1631
1632 fn is_nan(self) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1633 self.apply(|platform, access| platform.is_nan(access))
1634 }
1635}
1636
/// Reductions of an entire array to a single `bool`.
pub trait NDArrayReduceBoolean: NDArrayRead {
    /// Reduce every element with the platform's `all` reduction (presumably `true` iff every
    /// element is nonzero; confirm against the platform implementation).
    fn all(self) -> Result<bool, Error>;

    /// Reduce every element with the platform's `any` reduction (presumably `true` iff at
    /// least one element is nonzero; confirm against the platform implementation).
    fn any(self) -> Result<bool, Error>;
}
1645
1646impl<T, A, P> NDArrayReduceBoolean for Array<T, A, P>
1647where
1648 T: Number,
1649 A: Access<T>,
1650 P: ReduceAll<A, T>,
1651{
1652 fn all(self) -> Result<bool, Error> {
1653 self.platform.all(self.access)
1654 }
1655
1656 fn any(self) -> Result<bool, Error> {
1657 self.platform.any(self.access)
1658 }
1659}
1660
/// Reductions of an entire array to a single value of its element type.
pub trait NDArrayReduceAll: NDArrayRead {
    /// Return the maximum element.
    fn max_all(self) -> Result<Self::DType, Error>
    where
        Self::DType: Real;

    /// Return the minimum element.
    fn min_all(self) -> Result<Self::DType, Error>
    where
        Self::DType: Real;

    /// Return the product of all elements.
    fn product_all(self) -> Result<Self::DType, Error>;

    /// Return the sum of all elements.
    fn sum_all(self) -> Result<Self::DType, Error>;
}
1679
1680impl<T, A, P> NDArrayReduceAll for Array<T, A, P>
1681where
1682 T: Number,
1683 A: Access<T>,
1684 P: ReduceAll<A, T>,
1685{
1686 fn max_all(self) -> Result<Self::DType, Error>
1687 where
1688 T: Real,
1689 {
1690 self.platform.max(self.access)
1691 }
1692
1693 fn min_all(self) -> Result<Self::DType, Error>
1694 where
1695 T: Real,
1696 {
1697 self.platform.min(self.access)
1698 }
1699
1700 fn product_all(self) -> Result<Self::DType, Error> {
1701 self.platform.product(self.access)
1702 }
1703
1704 fn sum_all(self) -> Result<T, Error> {
1705 self.platform.sum(self.access)
1706 }
1707}
1708
1709impl<T, A, P> fmt::Debug for Array<T, A, P> {
1710 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1711 write!(
1712 f,
1713 "a {} array of shape {:?}",
1714 std::any::type_name::<T>(),
1715 self.shape
1716 )
1717 }
1718}
1719
/// Elementwise trigonometric, inverse trigonometric, and hyperbolic functions.
pub trait NDArrayTrig: NDArray + Sized {
    /// Accessor type of the result.
    type Output: Access<Self::DType>;

    /// Sine of each element.
    fn sin(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;

    /// Arcsine of each element.
    fn asin(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;

    /// Hyperbolic sine of each element.
    fn sinh(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;

    /// Cosine of each element.
    fn cos(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;

    /// Arccosine of each element.
    fn acos(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;

    /// Hyperbolic cosine of each element.
    fn cosh(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;

    /// Tangent of each element.
    fn tan(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;

    /// Arctangent of each element.
    fn atan(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;

    /// Hyperbolic tangent of each element.
    fn tanh(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
}
1751
1752impl<T, A, P> NDArrayTrig for Array<T, A, P>
1753where
1754 T: Float,
1755 A: Access<T>,
1756 P: ElementwiseTrig<A, T>,
1757{
1758 type Output = AccessOp<P::Op, P>;
1759
1760 fn sin(self) -> Result<Array<T, Self::Output, Self::Platform>, Error> {
1761 self.apply(|platform, access| platform.sin(access))
1762 }
1763
1764 fn asin(self) -> Result<Array<T, Self::Output, Self::Platform>, Error> {
1765 self.apply(|platform, access| platform.asin(access))
1766 }
1767
1768 fn sinh(self) -> Result<Array<T, Self::Output, Self::Platform>, Error> {
1769 self.apply(|platform, access| platform.sinh(access))
1770 }
1771
1772 fn cos(self) -> Result<Array<T, Self::Output, Self::Platform>, Error> {
1773 self.apply(|platform, access| platform.cos(access))
1774 }
1775
1776 fn acos(self) -> Result<Array<T, Self::Output, Self::Platform>, Error> {
1777 self.apply(|platform, access| platform.acos(access))
1778 }
1779
1780 fn cosh(self) -> Result<Array<T, Self::Output, Self::Platform>, Error> {
1781 self.apply(|platform, access| platform.cosh(access))
1782 }
1783
1784 fn tan(self) -> Result<Array<T, Self::Output, Self::Platform>, Error> {
1785 self.apply(|platform, access| platform.tan(access))
1786 }
1787
1788 fn atan(self) -> Result<Array<T, Self::Output, Self::Platform>, Error> {
1789 self.apply(|platform, access| platform.atan(access))
1790 }
1791
1792 fn tanh(self) -> Result<Array<T, Self::Output, Self::Platform>, Error> {
1793 self.apply(|platform, access| platform.tanh(access))
1794 }
1795}
1796
/// Conditional selection driven by a `u8` condition array.
pub trait NDArrayWhere<T, L, R>: NDArray<DType = u8> + fmt::Debug
where
    T: Number,
{
    /// Accessor type of the selected result.
    type Output: Access<T>;

    /// Build an array whose elements come from `then` or `or_else` depending on this condition
    /// array (presumably `then` where the condition is nonzero; confirm against the
    /// platform's `GatherCond` implementation).
    ///
    /// The `Array` implementation in this module requires the condition and both branches to
    /// have exactly the same shape. Otherwise it returns a bounds error.
    fn cond(self, then: L, or_else: R) -> Result<Array<T, Self::Output, Self::Platform>, Error>;
}
1809
1810impl<T, A, L, R, P> NDArrayWhere<T, Array<T, L, P>, Array<T, R, P>> for Array<u8, A, P>
1811where
1812 T: Number,
1813 A: Access<u8>,
1814 L: Access<T>,
1815 R: Access<T>,
1816 P: GatherCond<A, L, R, T>,
1817{
1818 type Output = AccessOp<P::Op, P>;
1819
1820 fn cond(
1821 self,
1822 then: Array<T, L, P>,
1823 or_else: Array<T, R, P>,
1824 ) -> Result<Array<T, Self::Output, Self::Platform>, Error> {
1825 same_shape("cond", self.shape(), then.shape())?;
1826 same_shape("cond", self.shape(), or_else.shape())?;
1827
1828 let access = self
1829 .platform
1830 .cond(self.access, then.access, or_else.access)?;
1831
1832 Ok(Array {
1833 shape: self.shape,
1834 access,
1835 platform: self.platform,
1836 dtype: PhantomData,
1837 })
1838 }
1839}
1840
/// Binary matrix operations between two arrays with the same data type.
pub trait MatrixDual<O>: NDArray + fmt::Debug
where
    O: NDArray<DType = Self::DType> + fmt::Debug,
{
    /// Accessor type of the product.
    type Output: Access<Self::DType>;

    /// Batched matrix multiplication over the last two axes.
    ///
    /// In the `Array` implementation in this module, `self` has shape `[..batch, a, b]` and
    /// `other` has shape `[..batch, b, c]`, with identical batch axes. The result has shape
    /// `[..batch, a, c]`. Any other combination is a bounds error.
    fn matmul(self, other: O) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
}
1851
1852impl<T, L, R, P> MatrixDual<Array<T, R, P>> for Array<T, L, P>
1853where
1854 T: Number,
1855 L: Access<T>,
1856 R: Access<T>,
1857 P: LinAlgDual<L, R, T>,
1858{
1859 type Output = AccessOp<P::Op, P>;
1860
1861 fn matmul(
1862 self,
1863 other: Array<T, R, P>,
1864 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1865 let dims = matmul_dims(&self.shape, &other.shape).ok_or_else(|| {
1866 Error::bounds(format!(
1867 "invalid dimensions for matrix multiply: {:?} and {:?}",
1868 self.shape, other.shape
1869 ))
1870 })?;
1871
1872 let mut shape = Shape::with_capacity(self.ndim());
1873 shape.extend(self.shape.iter().rev().skip(2).rev().copied());
1874 shape.push(dims[1]);
1875 shape.push(dims[3]);
1876
1877 let platform = P::select(dims.iter().product());
1878
1879 let access = platform.matmul(self.access, other.access, dims)?;
1880
1881 Ok(Array {
1882 shape,
1883 access,
1884 platform,
1885 dtype: self.dtype,
1886 })
1887 }
1888}
1889
/// Unary operations on the matrices stored in the last two axes of an array.
pub trait MatrixUnary: NDArray + fmt::Debug {
    /// Accessor type of the diagonal.
    type Diag: Access<Self::DType>;

    /// Accessor type of the transpose.
    type Transpose: Access<Self::DType>;

    /// Swap the last two axes (matrix transpose), leaving any leading batch axes in place.
    fn mt(self) -> Result<Array<Self::DType, Self::Transpose, Self::Platform>, Error>;

    /// Extract the diagonal of each square matrix in the last two axes.
    ///
    /// In the `Array` implementation in this module, the result drops the last axis. A shape
    /// with fewer than two axes, or whose last two axes differ, is a bounds error.
    fn diag(self) -> Result<Array<Self::DType, Self::Diag, Self::Platform>, Error>;
}
1902
1903impl<T, A, P> MatrixUnary for Array<T, A, P>
1904where
1905 T: Number,
1906 A: Access<T>,
1907 P: LinAlgUnary<A, T> + Transform<A, T>,
1908{
1909 type Diag = AccessOp<<P as LinAlgUnary<A, T>>::Op, P>;
1910 type Transpose = AccessOp<<P as Transform<A, T>>::Transpose, P>;
1911
1912 fn mt(self) -> Result<Array<Self::DType, Self::Transpose, Self::Platform>, Error> {
1913 let ndim = self.ndim();
1914 let mut permutation = Axes::with_capacity(ndim);
1915 permutation.extend(0..self.ndim() - 2);
1916 permutation.push(ndim - 1);
1917 permutation.push(ndim - 2);
1918 self.transpose(permutation)
1919 }
1920
1921 fn diag(self) -> Result<Array<T, AccessOp<P::Op, P>, P>, Error> {
1922 if self.ndim() >= 2 && self.shape.last() == self.shape.iter().nth_back(1) {
1923 let batch_size = self.shape.iter().rev().skip(2).product();
1924 let dim = self.shape.last().copied().expect("dim");
1925
1926 let shape = self.shape.iter().rev().skip(1).rev().copied().collect();
1927 let platform = P::select(batch_size * dim * dim);
1928 let access = platform.diag(self.access, batch_size, dim)?;
1929
1930 Ok(Array {
1931 shape,
1932 access,
1933 platform,
1934 dtype: PhantomData,
1935 })
1936 } else {
1937 Err(Error::bounds(format!(
1938 "invalid shape for diagonal: {:?}",
1939 self.shape
1940 )))
1941 }
1942 }
1943}
1944
#[cfg(feature = "complex")]
/// Matrix operations specific to complex-valued arrays.
pub trait MatrixUnaryComplex: MatrixUnary
where
    Self::DType: Complex,
{
    /// Accessor type of the conjugate transpose.
    type Hermitian: Access<Self::DType>;

    /// Conjugate transpose: swap the last two axes and conjugate every element.
    fn mh(self) -> Result<Array<Self::DType, Self::Hermitian, Self::Platform>, Error>;
}
1956
#[cfg(feature = "complex")]
impl<T, A, P> MatrixUnaryComplex for Array<T, A, P>
where
    T: Complex,
    A: Access<T>,
    P: complex::ElementwiseUnaryComplex<Self::Transpose, T> + LinAlgUnary<A, T> + Transform<A, T>,
{
    type Hermitian = AccessOp<P::Complex, P>;

    /// Transpose the last two axes with `mt`, then conjugate the transposed elements.
    fn mh(self) -> Result<Array<Self::DType, Self::Hermitian, Self::Platform>, Error> {
        let transposed = self.mt()?;
        transposed.conj()
    }
}
1970
/// Return `true` if two shapes are broadcast-compatible.
///
/// Shapes are aligned from their trailing axes. Each aligned pair must be equal or have a
/// `1` on either side. Leading axes of the longer shape are unconstrained, which is why
/// `zip` simply stops at the shorter shape.
#[inline]
fn can_broadcast(left: &[usize], right: &[usize]) -> bool {
    left.iter()
        .rev()
        .zip(right.iter().rev())
        .all(|(&l, &r)| l == r || l == 1 || r == 1)
}
1987
/// Compute `[batch_size, a, b, c]` for a batched product `[..batch, a, b] x [..batch, b, c]`.
///
/// Returns `None` when either shape has fewer than two axes, the inner dimensions differ,
/// or the batch axes are not identical (same count and same lengths). `batch_size` is the
/// product of the batch axes, which is 1 when there are none.
#[inline]
fn matmul_dims(left: &[usize], right: &[usize]) -> Option<[usize; 4]> {
    let (left_batch, left_matrix) = left.split_at(left.len().checked_sub(2)?);
    let (right_batch, right_matrix) = right.split_at(right.len().checked_sub(2)?);

    let (a, b) = (left_matrix[0], left_matrix[1]);
    let (inner, c) = (right_matrix[0], right_matrix[1]);

    if inner != b || left_batch != right_batch {
        return None;
    }

    // Multiply innermost-first, matching the original accumulation order.
    let batch_size: usize = left_batch.iter().rev().product();

    Some([batch_size, a, b, c])
}
2014
2015#[inline]
2016fn permute_for_reduce<'a, T, A, P>(
2017 platform: P,
2018 access: A,
2019 shape: Shape,
2020 axes: Axes,
2021) -> Result<Accessor<'a, T>, Error>
2022where
2023 T: Number,
2024 A: Access<T>,
2025 P: Transform<A, T>,
2026 Accessor<'a, T>: From<A> + From<AccessOp<P::Transpose, P>>,
2027{
2028 let mut permutation = Axes::with_capacity(shape.len());
2029 permutation.extend((0..shape.len()).filter(|x| !axes.contains(x)));
2030 permutation.extend(axes);
2031
2032 if permutation.iter().copied().enumerate().all(|(i, x)| i == x) {
2033 Ok(Accessor::from(access))
2034 } else {
2035 platform
2036 .transpose(access, shape, permutation)
2037 .map(Accessor::from)
2038 }
2039}
2040
2041#[inline]
2042fn reduce_axes(shape: &[usize], axes: &[usize], keepdims: bool) -> Result<Shape, Error> {
2043 let mut shape = Shape::from_slice(shape);
2044
2045 for x in axes.iter().copied().rev() {
2046 if x >= shape.len() {
2047 return Err(Error::bounds(format!(
2048 "axis {x} is out of bounds for {shape:?}"
2049 )));
2050 } else if keepdims {
2051 shape[x] = 1;
2052 } else {
2053 shape.remove(x);
2054 }
2055 }
2056
2057 if shape.is_empty() {
2058 Ok(shape![1])
2059 } else {
2060 Ok(shape)
2061 }
2062}
2063
2064#[inline]
2065pub fn same_shape(op_name: &'static str, left: &[usize], right: &[usize]) -> Result<(), Error> {
2066 if left == right {
2067 Ok(())
2068 } else if can_broadcast(left, right) {
2069 Err(Error::bounds(format!(
2070 "cannot {op_name} arrays with shapes {left:?} and {right:?} (consider broadcasting)"
2071 )))
2072 } else {
2073 Err(Error::bounds(format!(
2074 "cannot {op_name} arrays with shapes {left:?} and {right:?}"
2075 )))
2076 }
2077}
2078
2079#[inline]
2080fn valid_coord(coord: &[usize], shape: &[usize]) -> Result<(), Error> {
2081 if coord.len() == shape.len() && coord.iter().zip(shape).all(|(i, dim)| i < dim) {
2082 return Ok(());
2083 }
2084
2085 Err(Error::bounds(format!(
2086 "invalid coordinate {coord:?} for shape {shape:?}"
2087 )))
2088}