Skip to main content

ha_ndarray/
array.rs

1#![allow(clippy::type_complexity)]
2// The public API is intentionally generic over access/platform types; explicit type aliases here
3// tend to obscure the actual bounds without meaningfully improving readability.
4
5use std::fmt;
6use std::marker::PhantomData;
7
8use crate::access::*;
9use crate::buffer::BufferInstance;
10use crate::ops::*;
11use crate::platform::PlatformInstance;
12#[cfg(feature = "complex")]
13use crate::Complex;
14use crate::{
15    axes, range_shape, shape, strides_for, ArrayAccess, Axes, AxisRange, BufferConverter, Constant,
16    Convert, Error, Float, Number, Platform, Range, Real, Shape,
17};
18
19pub struct Array<T, A, P> {
20    shape: Shape,
21    access: A,
22    platform: P,
23    dtype: PhantomData<T>,
24}
25
26impl<T, A: Clone, P: Clone> Clone for Array<T, A, P> {
27    fn clone(&self) -> Self {
28        Self {
29            shape: self.shape.clone(),
30            access: self.access.clone(),
31            platform: self.platform.clone(),
32            dtype: self.dtype,
33        }
34    }
35}
36
37impl<T, A, P> Array<T, A, P> {
38    fn apply<O, OT, Op>(self, op: Op) -> Result<Array<OT, AccessOp<O, P>, P>, Error>
39    where
40        P: Copy,
41        Op: Fn(P, A) -> Result<AccessOp<O, P>, Error>,
42    {
43        let access = (op)(self.platform, self.access)?;
44
45        Ok(Array {
46            shape: self.shape,
47            access,
48            platform: self.platform,
49            dtype: PhantomData,
50        })
51    }
52
53    fn reduce_axes<'a, Op>(
54        self,
55        mut axes: Axes,
56        keepdims: bool,
57        op: Op,
58    ) -> Result<Array<T, AccessOp<P::Op, P>, P>, Error>
59    where
60        T: Number,
61        A: Access<T>,
62        P: Transform<A, T> + ReduceAxes<Accessor<'a, T>, T>,
63        Op: Fn(P, Accessor<'a, T>, usize) -> Result<AccessOp<P::Op, P>, Error>,
64        Accessor<'a, T>: From<A> + From<AccessOp<P::Transpose, P>>,
65    {
66        axes.sort();
67        axes.dedup();
68
69        let platform = P::select(self.size());
70        let stride = axes.iter().copied().map(|x| self.shape[x]).product();
71        let shape = reduce_axes(&self.shape, &axes, keepdims)?;
72
73        let access = permute_for_reduce(self.platform, self.access, self.shape, axes)?;
74        let access = (op)(self.platform, access, stride)?;
75
76        Ok(Array {
77            access,
78            shape,
79            platform,
80            dtype: PhantomData,
81        })
82    }
83
84    pub fn access(&self) -> &A {
85        &self.access
86    }
87
88    pub fn into_access(self) -> A {
89        self.access
90    }
91}
92
93impl<T, L, P> Array<T, L, P> {
94    fn apply_dual<O, OT, R, Op>(
95        self,
96        other: Array<T, R, P>,
97        op: Op,
98    ) -> Result<Array<OT, AccessOp<O, P>, P>, Error>
99    where
100        P: Copy,
101        Op: Fn(P, L, R) -> Result<AccessOp<O, P>, Error>,
102    {
103        let access = (op)(self.platform, self.access, other.access)?;
104
105        Ok(Array {
106            shape: self.shape,
107            access,
108            platform: self.platform,
109            dtype: PhantomData,
110        })
111    }
112}
113
114// constructors
115impl<'a, T: Number> Array<T, Accessor<'a, T>, Platform> {
116    pub fn from<A, P>(array: Array<T, A, P>) -> Self
117    where
118        A: Into<Accessor<'a, T>>,
119        Platform: From<P>,
120    {
121        Self {
122            shape: array.shape,
123            access: array.access.into(),
124            platform: array.platform.into(),
125            dtype: array.dtype,
126        }
127    }
128}
129
130impl<T, B, P> Array<T, AccessBuf<B>, P>
131where
132    T: Number,
133    B: BufferInstance<T>,
134    P: PlatformInstance,
135{
136    fn new_inner(platform: P, buffer: B, shape: Shape) -> Result<Self, Error> {
137        if !shape.is_empty() && shape.iter().product::<usize>() == buffer.len() {
138            let access = buffer.into();
139
140            Ok(Self {
141                shape,
142                access,
143                platform,
144                dtype: PhantomData,
145            })
146        } else {
147            Err(Error::bounds(format!(
148                "cannot construct an array with shape {shape:?} from a buffer of size {}",
149                buffer.len(),
150            )))
151        }
152    }
153
154    pub fn convert<'a, FB>(buffer: FB, shape: Shape) -> Result<Self, Error>
155    where
156        FB: Into<BufferConverter<'a, T>>,
157        P: Convert<T, Buffer = B>,
158    {
159        let buffer = buffer.into();
160        let platform = P::select(buffer.len());
161        let buffer = platform.convert(buffer)?;
162        Self::new_inner(platform, buffer, shape)
163    }
164
165    pub fn new(buffer: B, shape: Shape) -> Result<Self, Error> {
166        let platform = P::select(buffer.len());
167        Self::new_inner(platform, buffer, shape)
168    }
169}
170
171impl<T, P> Array<T, AccessBuf<P::Buffer>, P>
172where
173    T: Number,
174    P: Constant<T>,
175{
176    pub fn constant(value: T, shape: Shape) -> Result<Self, Error> {
177        if !shape.is_empty() {
178            let size = shape.iter().product();
179            let platform = P::select(size);
180            let buffer = platform.constant(value, size)?;
181            let access = buffer.into();
182
183            Ok(Self {
184                shape,
185                access,
186                platform,
187                dtype: PhantomData,
188            })
189        } else {
190            Err(Error::bounds(
191                "cannot construct an array with an empty shape".to_string(),
192            ))
193        }
194    }
195}
196
197// copy constructors
198impl<T, A, P> Array<T, A, P>
199where
200    T: Number,
201    A: Access<T>,
202    P: Convert<T>,
203{
204    pub fn copy(&self) -> Result<Array<T, AccessBuf<P::Buffer>, P>, Error> {
205        let buffer = self.buffer().and_then(|buf| self.platform.convert(buf))?;
206
207        Ok(Array {
208            shape: self.shape.clone(),
209            access: buffer.into(),
210            platform: self.platform,
211            dtype: self.dtype,
212        })
213    }
214}
215
216// op constructors
217impl<T, A, P> Array<T, A, P>
218where
219    T: Number,
220    A: Access<T>,
221    P: Transform<A, T>,
222    P: ConstructConcat<AccessOp<<P as Transform<A, T>>::Transpose, P>, T>,
223    P: Transform<
224        AccessOp<<P as ConstructConcat<AccessOp<<P as Transform<A, T>>::Transpose, P>, T>>::Op, P>,
225        T,
226    >,
227{
228    pub fn stack<AS>(arrays: AS, axis: usize) -> Result<Array<T, impl Access<T>, P>, Error>
229    where
230        AS: IntoIterator<Item = Self>,
231    {
232        let arrays = arrays
233            .into_iter()
234            .map(|arr| arr.unsqueeze(axes![axis]))
235            .collect::<Result<Vec<_>, Error>>()?;
236
237        Array::transpose_concat(arrays, axis)
238    }
239
240    pub fn transpose_concat(
241        arrays: Vec<Self>,
242        axis: usize,
243    ) -> Result<Array<T, impl Access<T>, P>, Error> {
244        let shape = if let Some(first) = arrays.first() {
245            let shape = first.shape();
246            if axis < shape.len() {
247                Ok(shape)
248            } else {
249                Err(Error::bounds(format!("{first:?} has no axis {axis}")))
250            }
251        } else {
252            Err(Error::bounds(
253                "cannot concatenate an empty list of arrays".to_string(),
254            ))
255        }?;
256
257        for array in arrays.iter().skip(1) {
258            if array.ndim() == shape.len() {
259                for (x, (dim, a_dim)) in shape.iter().zip(array.shape()).enumerate() {
260                    if x != axis && dim != a_dim {
261                        return Err(Error::bounds(format!(
262                            "cannot concatenate {:?} with {:?} at axis {axis}",
263                            shape,
264                            array.shape()
265                        )));
266                    }
267                }
268            } else {
269                return Err(Error::bounds(format!(
270                    "cannot concatenate {:?} with {:?}",
271                    shape,
272                    array.shape()
273                )));
274            }
275        }
276
277        let mut permutation: Axes = (0..shape.len()).collect();
278        permutation.swap(0, axis);
279
280        let arrays = arrays
281            .into_iter()
282            .map(|array| array.transpose(permutation.clone()))
283            .collect::<Result<Vec<Array<T, _, P>>, Error>>()?;
284
285        Array::concat(arrays)?.transpose(permutation)
286    }
287}
288
289impl<T, A, P> Array<T, A, P>
290where
291    T: Number,
292    A: Access<T>,
293    P: ConstructConcat<A, T>,
294{
295    pub fn concat(arrays: Vec<Self>) -> Result<Array<T, AccessOp<P::Op, P>, P>, Error> {
296        let mut array_iter = arrays.iter();
297        let first = array_iter.next();
298
299        if let Some(first) = first {
300            let mut shape = Shape::from_slice(first.shape());
301            for next in array_iter {
302                if next.ndim() != shape.len()
303                    || (shape.len() > 1 && shape[1..] != next.shape()[1..])
304                {
305                    return Err(Error::bounds(format!(
306                        "cannot concatenate shapes {:?} and {:?}",
307                        shape,
308                        next.shape()
309                    )));
310                } else {
311                    shape[0] += next.shape()[0];
312                }
313            }
314
315            Self::concat_inner(arrays, shape)
316        } else {
317            Err(Error::bounds(
318                "cannot concatenate an empty list of arrays".into(),
319            ))
320        }
321    }
322
323    fn concat_inner(
324        arrays: Vec<Array<T, A, P>>,
325        shape: Shape,
326    ) -> Result<Array<T, AccessOp<P::Op, P>, P>, Error> {
327        let platform = P::select(shape.iter().product());
328
329        let data = arrays
330            .into_iter()
331            .map(|array| array.into_access())
332            .collect();
333
334        platform.concat(data).map(|access| Array {
335            shape,
336            access,
337            platform,
338            dtype: PhantomData,
339        })
340    }
341}
342
343impl<T: Number, P: PlatformInstance> Array<T, AccessOp<P::Range, P>, P>
344where
345    P: ConstructRange<T>,
346{
347    pub fn range(start: T, stop: T, shape: Shape) -> Result<Self, Error> {
348        let size = shape.iter().product();
349        let platform = P::select(size);
350
351        platform.range(start, stop, size).map(|access| Self {
352            shape,
353            access,
354            platform,
355            dtype: PhantomData,
356        })
357    }
358}
359
360impl<P: PlatformInstance> Array<f32, AccessOp<P::Normal, P>, P>
361where
362    P: Random,
363{
364    pub fn random_normal(size: usize) -> Result<Self, Error> {
365        let platform = P::select(size);
366        let shape = shape![size];
367
368        platform.random_normal(size).map(|access| Self {
369            shape,
370            access,
371            platform,
372            dtype: PhantomData,
373        })
374    }
375}
376
377impl<P: PlatformInstance> Array<f32, AccessOp<P::Uniform, P>, P>
378where
379    P: Random,
380{
381    pub fn random_uniform(size: usize) -> Result<Self, Error> {
382        let platform = P::select(size);
383        let shape = shape![size];
384
385        platform.random_uniform(size).map(|access| Self {
386            shape,
387            access,
388            platform,
389            dtype: PhantomData,
390        })
391    }
392}
393
394// references
395impl<T, A, P> Array<T, A, P>
396where
397    T: Number,
398    A: Access<T>,
399    P: PlatformInstance,
400{
401    pub fn as_mut<'a, B>(&'a mut self) -> Array<T, B, P>
402    where
403        A: AccessBorrowMut<'a, T, B>,
404        B: AccessMut<T> + 'a,
405    {
406        Array {
407            shape: Shape::from_slice(&self.shape),
408            access: AccessBorrowMut::borrow_mut(&mut self.access),
409            platform: self.platform,
410            dtype: PhantomData,
411        }
412    }
413
414    pub fn as_ref<'a, B>(&'a self) -> Array<T, B, P>
415    where
416        A: AccessBorrow<'a, T, B>,
417        B: Access<T> + 'a,
418    {
419        Array {
420            shape: Shape::from_slice(&self.shape),
421            access: AccessBorrow::borrow(&self.access),
422            platform: self.platform,
423            dtype: PhantomData,
424        }
425    }
426}
427
428// helper methods
429
430impl<'a, T: Number> ArrayAccess<'a, T> {
431    pub fn unstack(
432        self,
433        axis: usize,
434    ) -> Result<Vec<Array<T, impl Access<T> + 'a, Platform>>, Error> {
435        let dim = self
436            .shape()
437            .get(axis)
438            .copied()
439            .ok_or_else(|| Error::bounds(format!("{self:?} has no axis {axis}")))?;
440
441        let prefix = if axis == 0 {
442            Range::with_capacity(1)
443        } else {
444            self.shape
445                .iter()
446                .take(axis)
447                .copied()
448                .map(|dim| AxisRange::In(0, dim, 1))
449                .collect()
450        };
451
452        (0..dim)
453            .map(|r| {
454                let mut range = prefix.clone();
455                range.push(AxisRange::At(r));
456                range
457            })
458            .map(|r| self.clone().slice(r))
459            .collect()
460    }
461}
462
463// traits
464
465/// An n-dimensional array
466pub trait NDArray: Send + Sync {
467    /// The data type of the elements in this array
468    type DType: Number;
469
470    /// The platform used to construct operations on this array.
471    type Platform: PlatformInstance;
472
473    /// Return the number of dimensions in this array.
474    fn ndim(&self) -> usize {
475        self.shape().len()
476    }
477
478    /// Return the number of elements in this array.
479    fn size(&self) -> usize {
480        self.shape().iter().product()
481    }
482
483    /// Borrow the shape of this array.
484    fn shape(&self) -> &[usize];
485}
486
487impl<T, A, P> NDArray for Array<T, A, P>
488where
489    T: Number,
490    A: Access<T>,
491    P: PlatformInstance,
492{
493    type DType = T;
494    type Platform = P;
495
496    fn shape(&self) -> &[usize] {
497        &self.shape
498    }
499}
500
501/// Array absolute value
502pub trait NDArrayAbs: NDArray + Sized {
503    /// The return type of the absolute value operation
504    type Output: Access<<Self::DType as Number>::Abs>;
505
506    /// Construct an absolute value operation.
507    fn abs(
508        self,
509    ) -> Result<Array<<Self::DType as Number>::Abs, Self::Output, Self::Platform>, Error>;
510}
511
512impl<T, A, P> NDArrayAbs for Array<T, A, P>
513where
514    T: Number,
515    A: Access<T>,
516    P: ElementwiseAbs<A, T>,
517{
518    type Output = AccessOp<P::Op, P>;
519
520    fn abs(self) -> Result<Array<T::Abs, Self::Output, Self::Platform>, Error> {
521        self.apply(|platform, access| platform.abs(access))
522    }
523}
524
525/// Access methods for an [`NDArray`]
526pub trait NDArrayRead: NDArray + fmt::Debug + Sized {
527    /// Read the value of this [`NDArray`] into a [`BufferConverter`].
528    fn buffer(&self) -> Result<BufferConverter<'_, Self::DType>, Error>;
529
530    /// Buffer this [`NDArray`] into a new, owned array, allocating only if needed.
531    fn into_read(
532        self,
533    ) -> Result<
534        Array<
535            Self::DType,
536            AccessBuf<<Self::Platform as Convert<Self::DType>>::Buffer>,
537            Self::Platform,
538        >,
539        Error,
540    >
541    where
542        Self::Platform: Convert<Self::DType>;
543
544    /// Read the value at a specific `coord` in this [`NDArray`].
545    fn read_value(&self, coord: &[usize]) -> Result<Self::DType, Error>;
546}
547
548impl<T, A, P> NDArrayRead for Array<T, A, P>
549where
550    T: Number,
551    A: Access<T>,
552    P: PlatformInstance,
553{
554    fn buffer(&self) -> Result<BufferConverter<'_, T>, Error> {
555        self.access.read()
556    }
557
558    fn into_read(self) -> Result<Array<Self::DType, AccessBuf<P::Buffer>, Self::Platform>, Error>
559    where
560        P: Convert<T>,
561    {
562        let buffer = self.buffer().and_then(|buf| self.platform.convert(buf))?;
563        debug_assert_eq!(buffer.len(), self.size());
564
565        Ok(Array {
566            shape: self.shape,
567            access: buffer.into(),
568            platform: self.platform,
569            dtype: self.dtype,
570        })
571    }
572
573    fn read_value(&self, coord: &[usize]) -> Result<T, Error> {
574        valid_coord(coord, self.shape())?;
575
576        let strides = strides_for(self.shape(), self.ndim());
577
578        let offset = coord
579            .iter()
580            .zip(strides)
581            .map(|(i, stride)| i * stride)
582            .sum();
583
584        self.access.read_value(offset)
585    }
586}
587
588/// Access methods for a mutable [`NDArray`]
589pub trait NDArrayWrite: NDArray + fmt::Debug + Sized {
590    /// Overwrite this [`NDArray`] with the value of the `other` array.
591    fn write<O: NDArrayRead<DType = Self::DType>>(&mut self, other: &O) -> Result<(), Error>;
592
593    /// Overwrite this [`NDArray`] with a constant scalar `value`.
594    fn write_value(&mut self, value: Self::DType) -> Result<(), Error>;
595
596    /// Write the given `value` at the given `coord` of this [`NDArray`].
597    fn write_value_at(&mut self, coord: &[usize], value: Self::DType) -> Result<(), Error>;
598}
599
600// write ops
601impl<T, A, P> NDArrayWrite for Array<T, A, P>
602where
603    T: Number,
604    A: AccessMut<T>,
605    P: PlatformInstance,
606{
607    fn write<O>(&mut self, other: &O) -> Result<(), Error>
608    where
609        O: NDArrayRead<DType = Self::DType>,
610    {
611        same_shape("write", self.shape(), other.shape())?;
612        other.buffer().and_then(|buf| self.access.write(buf))
613    }
614
615    fn write_value(&mut self, value: Self::DType) -> Result<(), Error> {
616        self.access.write_value(value)
617    }
618
619    fn write_value_at(&mut self, coord: &[usize], value: Self::DType) -> Result<(), Error> {
620        valid_coord(coord, self.shape())?;
621
622        let offset = coord
623            .iter()
624            .zip(strides_for(self.shape(), self.ndim()))
625            .map(|(i, stride)| i * stride)
626            .sum();
627
628        self.access.write_value_at(offset, value)
629    }
630}
631
632// op traits
633
634/// Array cast operations
635pub trait NDArrayCast<OT: Number>: NDArray + Sized {
636    type Output: Access<OT>;
637
638    /// Construct a new array cast operation.
639    fn cast(self) -> Result<Array<OT, Self::Output, Self::Platform>, Error>;
640}
641
642impl<IT, OT, A, P> NDArrayCast<OT> for Array<IT, A, P>
643where
644    IT: Number,
645    OT: Number,
646    A: Access<IT>,
647    P: ElementwiseCast<A, IT, OT>,
648{
649    type Output = AccessOp<P::Op, P>;
650
651    fn cast(self) -> Result<Array<OT, AccessOp<P::Op, P>, P>, Error> {
652        Ok(Array {
653            shape: self.shape,
654            access: self.platform.cast(self.access)?,
655            platform: self.platform,
656            dtype: PhantomData,
657        })
658    }
659}
660
661/// Axis-wise array reduce operations
662pub trait NDArrayReduce<'a>: NDArray + fmt::Debug {
663    type Output: Access<Self::DType> + 'a;
664
665    /// Construct a max-reduce operation over the given `axes`.
666    fn max(
667        self,
668        axes: Axes,
669        keepdims: bool,
670    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>
671    where
672        Self::DType: Real;
673
674    /// Construct a min-reduce operation over the given `axes`.
675    fn min(
676        self,
677        axes: Axes,
678        keepdims: bool,
679    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>
680    where
681        Self::DType: Real;
682
683    /// Construct a product-reduce operation over the given `axes`.
684    fn product(
685        self,
686        axes: Axes,
687        keepdims: bool,
688    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
689
690    /// Construct a sum-reduce operation over the given `axes`.
691    fn sum(
692        self,
693        axes: Axes,
694        keepdims: bool,
695    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
696}
697
698impl<'a, T, A, P> NDArrayReduce<'a> for Array<T, A, P>
699where
700    T: Number + 'a,
701    A: Access<T> + 'a,
702    P: Transform<A, T> + ReduceAxes<Accessor<'a, T>, T> + 'a,
703    Accessor<'a, T>: From<A> + From<AccessOp<P::Transpose, P>> + 'a,
704{
705    type Output = AccessOp<P::Op, P>;
706
707    fn max(
708        self,
709        axes: Axes,
710        keepdims: bool,
711    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>
712    where
713        T: Real,
714    {
715        self.reduce_axes(axes, keepdims, |platform, access, stride| {
716            ReduceAxes::max(platform, access, stride)
717        })
718    }
719
720    fn min(
721        self,
722        axes: Axes,
723        keepdims: bool,
724    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>
725    where
726        T: Real,
727    {
728        self.reduce_axes(axes, keepdims, |platform, access, stride| {
729            ReduceAxes::min(platform, access, stride)
730        })
731    }
732
733    fn product(
734        self,
735        axes: Axes,
736        keepdims: bool,
737    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
738        self.reduce_axes(axes, keepdims, |platform, access, stride| {
739            ReduceAxes::product(platform, access, stride)
740        })
741    }
742
743    fn sum(
744        self,
745        axes: Axes,
746        keepdims: bool,
747    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
748        self.reduce_axes(axes, keepdims, |platform, access, stride| {
749            ReduceAxes::sum(platform, access, stride)
750        })
751    }
752}
753
754/// Array transform operations
755pub trait NDArrayTransform: NDArray + Sized + fmt::Debug {
756    /// The type returned by `broadcast`
757    type Broadcast: Access<Self::DType>;
758
759    /// The type returned by `flip`
760    type Flip: Access<Self::DType>;
761
762    /// The type returned by `slice`
763    type Slice: Access<Self::DType>;
764
765    /// The type returned by `transpose`
766    type Transpose: Access<Self::DType>;
767
768    /// Broadcast this array into the given `shape`.
769    fn broadcast(
770        self,
771        shape: Shape,
772    ) -> Result<Array<Self::DType, Self::Broadcast, Self::Platform>, Error>;
773
774    fn flip(self, axis: usize) -> Result<Array<Self::DType, Self::Flip, Self::Platform>, Error>;
775
776    /// Reshape this `array`.
777    fn reshape(self, shape: Shape) -> Result<Self, Error>;
778
779    /// Construct a slice of this array.
780    fn slice(self, range: Range) -> Result<Array<Self::DType, Self::Slice, Self::Platform>, Error>;
781
782    /// Contract the given `axes` of this array.
783    /// This will return an error if any of the `axes` have dimension > 1.
784    fn squeeze(self, axes: Axes) -> Result<Self, Error>;
785
786    /// Expand the given `axes` of this array.
787    fn unsqueeze(self, axes: Axes) -> Result<Self, Error>;
788
789    /// Transpose this array according to the given `permutation`.
790    /// If no permutation is given, the array axes will be reversed.
791    fn transpose<P: Into<Option<Axes>>>(
792        self,
793        permutation: P,
794    ) -> Result<Array<Self::DType, Self::Transpose, Self::Platform>, Error>;
795}
796
797impl<T, A, P> NDArrayTransform for Array<T, A, P>
798where
799    T: Number,
800    A: Access<T>,
801    P: Transform<A, T>,
802{
803    type Broadcast = AccessOp<P::Broadcast, P>;
804    type Flip = AccessOp<P::Flip, P>;
805    type Slice = AccessOp<P::Slice, P>;
806    type Transpose = AccessOp<P::Transpose, P>;
807
808    fn broadcast(self, shape: Shape) -> Result<Array<T, AccessOp<P::Broadcast, P>, P>, Error> {
809        if !can_broadcast(self.shape(), &shape) {
810            return Err(Error::bounds(format!(
811                "cannot broadcast {self:?} into {shape:?}"
812            )));
813        }
814
815        let platform = P::select(shape.iter().product());
816        let broadcast = Shape::from_slice(&shape);
817        let access = platform.broadcast(self.access, self.shape, broadcast)?;
818
819        Ok(Array {
820            shape,
821            access,
822            platform,
823            dtype: self.dtype,
824        })
825    }
826
827    fn flip(self, axis: usize) -> Result<Array<T, AccessOp<P::Flip, P>, P>, Error> {
828        let platform = self.platform;
829        let access = platform.flip(self.access, self.shape.clone(), axis)?;
830
831        Ok(Array {
832            shape: self.shape,
833            access,
834            platform,
835            dtype: self.dtype,
836        })
837    }
838
839    fn reshape(mut self, shape: Shape) -> Result<Self, Error> {
840        if shape.iter().product::<usize>() == self.size() {
841            self.shape = shape;
842            Ok(self)
843        } else {
844            Err(Error::bounds(format!(
845                "cannot reshape an array with shape {:?} into {shape:?}",
846                self.shape
847            )))
848        }
849    }
850
851    fn slice(self, mut range: Range) -> Result<Array<T, AccessOp<P::Slice, P>, P>, Error> {
852        for (dim, range) in self.shape.iter().zip(&range) {
853            match range {
854                AxisRange::At(i) if i < dim => Ok(()),
855                AxisRange::In(start, stop, _step) if start < dim && stop <= dim => Ok(()),
856                AxisRange::Of(indices) if indices.iter().all(|i| i < dim) => Ok(()),
857                range => Err(Error::bounds(format!(
858                    "invalid range {range:?} for dimension {dim}"
859                ))),
860            }?;
861        }
862
863        for dim in self.shape.iter().skip(range.len()).copied() {
864            range.push(AxisRange::In(0, dim, 1));
865        }
866
867        let shape = range_shape(self.shape(), &range);
868        let access = self.platform.slice(self.access, &self.shape, range)?;
869        let platform = P::select(shape.iter().product());
870
871        Ok(Array {
872            shape,
873            access,
874            platform,
875            dtype: self.dtype,
876        })
877    }
878
879    fn squeeze(mut self, mut axes: Axes) -> Result<Self, Error> {
880        axes.sort();
881
882        for x in axes.into_iter().rev() {
883            if x < self.shape.len() {
884                self.shape.remove(x);
885            } else {
886                return Err(Error::bounds(format!("axis out of bounds: {x}")));
887            }
888        }
889
890        Ok(self)
891    }
892
893    fn unsqueeze(mut self, axes: Axes) -> Result<Self, Error> {
894        for x in axes {
895            if x <= self.shape.len() {
896                self.shape.insert(x, 1);
897            } else {
898                return Err(Error::bounds(format!("axis out of bounds: {x}")));
899            }
900        }
901
902        Ok(self)
903    }
904
905    fn transpose<PA: Into<Option<Axes>>>(
906        self,
907        permutation: PA,
908    ) -> Result<Array<T, AccessOp<P::Transpose, P>, P>, Error> {
909        let permutation = if let Some(axes) = permutation.into() {
910            if axes.len() == self.ndim()
911                && axes.iter().copied().all(|x| x < self.ndim())
912                && !(1..axes.len()).any(|i| axes[i..].contains(&axes[i - 1]))
913            {
914                Ok(axes)
915            } else {
916                Err(Error::bounds(format!(
917                    "invalid permutation for shape {:?}: {:?}",
918                    self.shape, axes
919                )))
920            }
921        } else {
922            Ok((0..self.ndim()).rev().collect())
923        }?;
924
925        let shape = permutation.iter().copied().map(|x| self.shape[x]).collect();
926        let platform = self.platform;
927        let access = platform.transpose(self.access, self.shape, permutation)?;
928
929        Ok(Array {
930            shape,
931            access,
932            platform,
933            dtype: self.dtype,
934        })
935    }
936}
937
938/// Unary array operations
939pub trait NDArrayUnary: NDArray + Sized {
940    /// The return type of a unary operation.
941    type Output: Access<Self::DType>;
942
943    /// Construct an exponentiation operation.
944    fn exp(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
945
946    /// Construct a natural logarithm operation.
947    fn ln(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
948
949    /// Construct an integer rounding operation.
950    fn round(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>
951    where
952        Self::DType: Real;
953}
954
955impl<T, A, P> NDArrayUnary for Array<T, A, P>
956where
957    T: Float,
958    A: Access<T>,
959    P: ElementwiseUnary<A, T>,
960{
961    type Output = AccessOp<P::Op, P>;
962
963    fn exp(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
964        self.apply(|platform, access| platform.exp(access))
965    }
966
967    fn ln(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>
968    where
969        P: ElementwiseUnary<A, T>,
970    {
971        self.apply(|platform, access| platform.ln(access))
972    }
973
974    fn round(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>
975    where
976        T: Real,
977    {
978        self.apply(|platform, access| platform.round(access))
979    }
980}
981
982/// Unary boolean array operations
983pub trait NDArrayUnaryBoolean: NDArray + Sized {
984    /// The return type of a unary operation.
985    type Output: Access<u8>;
986
987    /// Construct a boolean not operation.
988    fn not(self) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
989}
990
991impl<T, A, P> NDArrayUnaryBoolean for Array<T, A, P>
992where
993    T: Number,
994    A: Access<T>,
995    P: ElementwiseUnaryBoolean<A, T>,
996{
997    type Output = AccessOp<P::Op, P>;
998
999    fn not(self) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1000        self.apply(|platform, access| platform.not(access))
1001    }
1002}
1003
1004/// Boolean array operations
1005pub trait NDArrayBoolean<O>: NDArray + Sized
1006where
1007    O: NDArray<DType = Self::DType>,
1008{
1009    type Output: Access<u8>;
1010
1011    /// Construct a boolean and comparison with the `other` array.
1012    fn and(self, other: O) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
1013
1014    /// Construct a boolean or comparison with the `other` array.
1015    fn or(self, other: O) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
1016
1017    /// Construct a boolean xor comparison with the `other` array.
1018    fn xor(self, other: O) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
1019}
1020
1021impl<T, L, R, P> NDArrayBoolean<Array<T, R, P>> for Array<T, L, P>
1022where
1023    T: Number,
1024    L: Access<T>,
1025    R: Access<T>,
1026    P: ElementwiseBoolean<L, R, T>,
1027{
1028    type Output = AccessOp<P::Op, P>;
1029
1030    fn and(self, other: Array<T, R, P>) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1031        same_shape("and", self.shape(), other.shape())?;
1032        self.apply_dual(other, |platform, left, right| platform.and(left, right))
1033    }
1034
1035    fn or(self, other: Array<T, R, P>) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1036        same_shape("or", self.shape(), other.shape())?;
1037        self.apply_dual(other, |platform, left, right| platform.or(left, right))
1038    }
1039
1040    fn xor(self, other: Array<T, R, P>) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1041        same_shape("xor", self.shape(), other.shape())?;
1042        self.apply_dual(other, |platform, left, right| platform.xor(left, right))
1043    }
1044}
1045
/// Boolean array operations with a scalar argument
///
/// The scalar has the same data type as the array, and every result has data type `u8`.
pub trait NDArrayBooleanScalar: NDArray + Sized {
    /// The return type of a boolean operation with a scalar.
    type Output: Access<u8>;

    /// Construct a boolean and operation with the `other` value.
    fn and_scalar(
        self,
        other: Self::DType,
    ) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;

    /// Construct a boolean or operation with the `other` value.
    fn or_scalar(
        self,
        other: Self::DType,
    ) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;

    /// Construct a boolean xor operation with the `other` value.
    fn xor_scalar(
        self,
        other: Self::DType,
    ) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
}
1068
1069impl<T, A, P> NDArrayBooleanScalar for Array<T, A, P>
1070where
1071    T: Number,
1072    A: Access<T>,
1073    P: ElementwiseBooleanScalar<A, T>,
1074{
1075    type Output = AccessOp<P::Op, P>;
1076
1077    fn and_scalar(
1078        self,
1079        other: Self::DType,
1080    ) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1081        self.apply(|platform, access| platform.and_scalar(access, other))
1082    }
1083
1084    fn or_scalar(
1085        self,
1086        other: Self::DType,
1087    ) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1088        self.apply(|platform, access| platform.or_scalar(access, other))
1089    }
1090
1091    fn xor_scalar(
1092        self,
1093        other: Self::DType,
1094    ) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1095        self.apply(|platform, access| platform.xor_scalar(access, other))
1096    }
1097}
1098
/// Array comparison operations
///
/// `O` is the type of the right-hand array; it must have the same data type as `Self`.
/// Every comparison produces an array with data type `u8`.
pub trait NDArrayCompare<O: NDArray<DType = Self::DType>>: NDArray + Sized {
    /// The return type of a comparison with another array.
    type Output: Access<u8>;

    /// Elementwise equality comparison
    fn eq(self, other: O) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;

    /// Elementwise greater-than-or-equal comparison
    fn ge(self, other: O) -> Result<Array<u8, Self::Output, Self::Platform>, Error>
    where
        Self::DType: Real;

    /// Elementwise greater-than comparison
    fn gt(self, other: O) -> Result<Array<u8, Self::Output, Self::Platform>, Error>
    where
        Self::DType: Real;

    /// Elementwise less-than-or-equal comparison
    fn le(self, other: O) -> Result<Array<u8, Self::Output, Self::Platform>, Error>
    where
        Self::DType: Real;

    /// Elementwise less-than comparison
    fn lt(self, other: O) -> Result<Array<u8, Self::Output, Self::Platform>, Error>
    where
        Self::DType: Real;

    /// Elementwise not-equal comparison
    fn ne(self, other: O) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
}
1129
1130impl<T, L, R, P> NDArrayCompare<Array<T, R, P>> for Array<T, L, P>
1131where
1132    T: Number,
1133    L: Access<T>,
1134    R: Access<T>,
1135    P: ElementwiseCompare<L, R, T>,
1136{
1137    type Output = AccessOp<P::Op, P>;
1138
1139    fn eq(self, other: Array<T, R, P>) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1140        same_shape("compare", self.shape(), other.shape())?;
1141        self.apply_dual(other, |platform, left, right| platform.eq(left, right))
1142    }
1143
1144    fn ge(self, other: Array<T, R, P>) -> Result<Array<u8, Self::Output, Self::Platform>, Error>
1145    where
1146        T: Real,
1147    {
1148        same_shape("compare", self.shape(), other.shape())?;
1149        self.apply_dual(other, |platform, left, right| platform.ge(left, right))
1150    }
1151
1152    fn gt(self, other: Array<T, R, P>) -> Result<Array<u8, Self::Output, Self::Platform>, Error>
1153    where
1154        T: Real,
1155    {
1156        same_shape("compare", self.shape(), other.shape())?;
1157        self.apply_dual(other, |platform, left, right| platform.gt(left, right))
1158    }
1159
1160    fn le(self, other: Array<T, R, P>) -> Result<Array<u8, Self::Output, Self::Platform>, Error>
1161    where
1162        T: Real,
1163    {
1164        same_shape("compare", self.shape(), other.shape())?;
1165        self.apply_dual(other, |platform, left, right| platform.le(left, right))
1166    }
1167
1168    fn lt(self, other: Array<T, R, P>) -> Result<Array<u8, Self::Output, Self::Platform>, Error>
1169    where
1170        T: Real,
1171    {
1172        same_shape("compare", self.shape(), other.shape())?;
1173        self.apply_dual(other, |platform, left, right| platform.lt(left, right))
1174    }
1175
1176    fn ne(self, other: Array<T, R, P>) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1177        same_shape("compare", self.shape(), other.shape())?;
1178        self.apply_dual(other, |platform, left, right| platform.ne(left, right))
1179    }
1180}
1181
/// Array-scalar comparison operations
///
/// The scalar has the same data type as the array, and every result has data type `u8`.
pub trait NDArrayCompareScalar: NDArray + Sized {
    /// The return type of a comparison with a scalar.
    type Output: Access<u8>;

    /// Construct an equality comparison with the `other` value.
    fn eq_scalar(
        self,
        other: Self::DType,
    ) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;

    /// Construct a greater-than comparison with the `other` value.
    fn gt_scalar(
        self,
        other: Self::DType,
    ) -> Result<Array<u8, Self::Output, Self::Platform>, Error>
    where
        Self::DType: Real;

    /// Construct an equal-or-greater-than comparison with the `other` value.
    fn ge_scalar(
        self,
        other: Self::DType,
    ) -> Result<Array<u8, Self::Output, Self::Platform>, Error>
    where
        Self::DType: Real;

    /// Construct a less-than comparison with the `other` value.
    fn lt_scalar(
        self,
        other: Self::DType,
    ) -> Result<Array<u8, Self::Output, Self::Platform>, Error>
    where
        Self::DType: Real;

    /// Construct an equal-or-less-than comparison with the `other` value.
    fn le_scalar(
        self,
        other: Self::DType,
    ) -> Result<Array<u8, Self::Output, Self::Platform>, Error>
    where
        Self::DType: Real;

    /// Construct a not-equal comparison with the `other` value.
    fn ne_scalar(
        self,
        other: Self::DType,
    ) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
}
1230
1231impl<T, A, P> NDArrayCompareScalar for Array<T, A, P>
1232where
1233    T: Number,
1234    A: Access<T>,
1235    P: ElementwiseCompareScalar<A, T>,
1236{
1237    type Output = AccessOp<P::Op, P>;
1238
1239    fn eq_scalar(
1240        self,
1241        other: Self::DType,
1242    ) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1243        self.apply(|platform, access| platform.eq_scalar(access, other))
1244    }
1245
1246    fn gt_scalar(self, other: Self::DType) -> Result<Array<u8, Self::Output, Self::Platform>, Error>
1247    where
1248        T: Real,
1249    {
1250        self.apply(|platform, access| platform.gt_scalar(access, other))
1251    }
1252
1253    fn ge_scalar(self, other: Self::DType) -> Result<Array<u8, Self::Output, Self::Platform>, Error>
1254    where
1255        T: Real,
1256    {
1257        self.apply(|platform, access| platform.ge_scalar(access, other))
1258    }
1259
1260    fn lt_scalar(self, other: Self::DType) -> Result<Array<u8, Self::Output, Self::Platform>, Error>
1261    where
1262        T: Real,
1263    {
1264        self.apply(|platform, access| platform.lt_scalar(access, other))
1265    }
1266
1267    fn le_scalar(self, other: Self::DType) -> Result<Array<u8, Self::Output, Self::Platform>, Error>
1268    where
1269        T: Real,
1270    {
1271        self.apply(|platform, access| platform.le_scalar(access, other))
1272    }
1273
1274    fn ne_scalar(
1275        self,
1276        other: Self::DType,
1277    ) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1278        self.apply(|platform, access| platform.ne_scalar(access, other))
1279    }
1280}
1281
#[cfg(feature = "complex")]
/// Complex array properties
pub trait NDArrayComplex: NDArray + Sized
where
    Self::DType: Complex,
{
    /// The accessor type of an operation which reads the real component type.
    type Real: Access<<Self::DType as Complex>::Real>;
    /// The accessor type of an operation which reads the complex data type.
    type Complex: Access<Self::DType>;

    /// Calculate the angle in the complex plane elementwise.
    fn angle(
        self,
    ) -> Result<Array<<Self::DType as Complex>::Real, Self::Real, Self::Platform>, Error>;

    /// Calculate the complex conjugate elementwise.
    fn conj(self) -> Result<Array<Self::DType, Self::Complex, Self::Platform>, Error>;

    /// Return the real part of this array elementwise.
    // NOTE(review): the returned array is typed with `Self::DType` although `Self::Real` is an
    // accessor over `<Self::DType as Complex>::Real` (compare `angle`) -- confirm this is intended.
    fn re(self) -> Result<Array<Self::DType, Self::Real, Self::Platform>, Error>;

    /// Return the imaginary part of this array elementwise.
    // NOTE(review): same dtype/accessor mismatch as `re` -- confirm this is intended.
    fn im(self) -> Result<Array<Self::DType, Self::Real, Self::Platform>, Error>;
}
1305
#[cfg(feature = "complex")]
impl<T: Complex, A: Access<T>, P: complex::ElementwiseUnaryComplex<A, T>> NDArrayComplex
    for Array<T, A, P>
{
    /// The accessor type the platform produces for real-valued results.
    type Real = AccessOp<P::Real, P>;
    /// The accessor type the platform produces for complex-valued results.
    type Complex = AccessOp<P::Complex, P>;

    /// Elementwise angle, delegated to `platform.angle`.
    fn angle(self) -> Result<Array<T::Real, Self::Real, P>, Error> {
        self.apply(|pf, source| pf.angle(source))
    }

    /// Elementwise conjugate, delegated to `platform.conj`.
    fn conj(self) -> Result<Array<T, Self::Complex, P>, Error> {
        self.apply(|pf, source| pf.conj(source))
    }

    /// Elementwise real part, delegated to `platform.re`.
    fn re(self) -> Result<Array<T, Self::Real, P>, Error> {
        self.apply(|pf, source| pf.re(source))
    }

    /// Elementwise imaginary part, delegated to `platform.im`.
    fn im(self) -> Result<Array<T, Self::Real, P>, Error> {
        self.apply(|pf, source| pf.im(source))
    }
}
1332
#[cfg(feature = "complex")]
/// Fourier transforms
pub trait NDArrayFourier: NDArray + Sized
where
    Self::DType: Complex,
{
    /// The return type of a Fourier transform operation.
    type Output: Access<Self::DType>;

    /// Calculate the Fourier transform of the last dimension of this array.
    fn fft(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;

    /// Calculate the inverse Fourier transform of the last dimension of this array.
    fn ifft(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
}
1347
#[cfg(feature = "complex")]
impl<A, T, P> NDArrayFourier for Array<num_complex::Complex<T>, A, P>
where
    A: Access<num_complex::Complex<T>>,
    num_complex::Complex<T>: Complex,
    P: complex::Fourier<A, num_complex::Complex<T>>,
{
    /// The accessor type the platform produces for a Fourier transform.
    type Output = AccessOp<P::Op, P>;

    /// Transform along the last axis via `platform.fft`.
    /// An array with an empty shape is rejected with a bounds error.
    fn fft(self) -> Result<Array<num_complex::Complex<T>, Self::Output, P>, Error> {
        let last_dim = self.shape.last().copied();

        match last_dim {
            Some(dim) => self.apply(|pf, source| pf.fft(source, dim)),
            None => Err(Error::bounds(
                "a scalar value has no Fourier transform".into(),
            )),
        }
    }

    /// Transform along the last axis via `platform.ifft`.
    /// An array with an empty shape is rejected with a bounds error.
    fn ifft(self) -> Result<Array<num_complex::Complex<T>, Self::Output, P>, Error> {
        let last_dim = self.shape.last().copied();

        match last_dim {
            Some(dim) => self.apply(|pf, source| pf.ifft(source, dim)),
            None => Err(Error::bounds(
                "a scalar value has no Fourier transform".into(),
            )),
        }
    }
}
1377
// TODO: it should be possible to implement this with a different other DType, e.g. C32 * f32
/// Array arithmetic operations
///
/// `O` is the type of the right-hand array; it must have the same data type as `Self`.
pub trait NDArrayMath<O: NDArray<DType = Self::DType>>: NDArray + Sized {
    /// The return type of an arithmetic operation with another array.
    type Output: Access<Self::DType>;

    /// Construct an addition operation with the given `rhs`.
    fn add(self, rhs: O) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;

    /// Construct a division operation with the given `rhs`.
    fn div(self, rhs: O) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;

    /// Construct a logarithm operation with the given `base`.
    fn log(self, base: O) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>
    where
        Self::DType: Float;

    /// Construct a multiplication operation with the given `rhs`.
    fn mul(self, rhs: O) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;

    /// Construct an operation to raise these data to the power of the given `exp`onent.
    fn pow(self, exp: O) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;

    /// Construct an array subtraction operation with the given `rhs`.
    fn sub(self, rhs: O) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;

    /// Construct a modulo operation with the given `rhs`.
    fn rem(self, rhs: O) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>
    where
        Self::DType: Real;
}
1408
1409impl<T, L, R, P> NDArrayMath<Array<T, R, P>> for Array<T, L, P>
1410where
1411    T: Number,
1412    L: Access<T>,
1413    R: Access<T>,
1414    P: ElementwiseDual<L, R, T>,
1415{
1416    type Output = AccessOp<P::Op, P>;
1417
1418    fn add(
1419        self,
1420        rhs: Array<T, R, P>,
1421    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1422        same_shape("add", self.shape(), rhs.shape())?;
1423        self.apply_dual(rhs, |platform, left, right| platform.add(left, right))
1424    }
1425
1426    fn div(
1427        self,
1428        rhs: Array<T, R, P>,
1429    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1430        same_shape("divide", self.shape(), rhs.shape())?;
1431        self.apply_dual(rhs, |platform, left, right| platform.div(left, right))
1432    }
1433
1434    fn log(
1435        self,
1436        base: Array<T, R, P>,
1437    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>
1438    where
1439        T: Float,
1440    {
1441        same_shape("log", self.shape(), base.shape())?;
1442        self.apply_dual(base, |platform, left, right| platform.log(left, right))
1443    }
1444
1445    fn mul(
1446        self,
1447        rhs: Array<T, R, P>,
1448    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1449        same_shape("multiply", self.shape(), rhs.shape())?;
1450        self.apply_dual(rhs, |platform, left, right| platform.mul(left, right))
1451    }
1452
1453    fn pow(
1454        self,
1455        exp: Array<T, R, P>,
1456    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1457        same_shape("exponentiate", self.shape(), exp.shape())?;
1458        self.apply_dual(exp, |platform, left, right| platform.pow(left, right))
1459    }
1460
1461    fn sub(
1462        self,
1463        rhs: Array<T, R, P>,
1464    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1465        same_shape("subtract", self.shape(), rhs.shape())?;
1466        self.apply_dual(rhs, |platform, left, right| platform.sub(left, right))
1467    }
1468
1469    fn rem(
1470        self,
1471        rhs: Array<T, R, P>,
1472    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>
1473    where
1474        T: Real,
1475    {
1476        same_shape("remainder", self.shape(), rhs.shape())?;
1477        self.apply_dual(rhs, |platform, left, right| platform.rem(left, right))
1478    }
1479}
1480
/// Array arithmetic operations with a scalar argument
///
/// The scalar has the same data type as the array, and so does every result.
pub trait NDArrayMathScalar: NDArray + Sized {
    /// The return type of an arithmetic operation with a scalar.
    type Output: Access<Self::DType>;

    /// Construct a scalar addition operation.
    fn add_scalar(
        self,
        rhs: Self::DType,
    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;

    /// Construct a scalar division operation.
    fn div_scalar(
        self,
        rhs: Self::DType,
    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;

    /// Construct a scalar logarithm operation.
    fn log_scalar(
        self,
        base: Self::DType,
    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>
    where
        Self::DType: Float;

    /// Construct a scalar multiplication operation.
    fn mul_scalar(
        self,
        rhs: Self::DType,
    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;

    /// Construct a scalar exponentiation operation.
    fn pow_scalar(
        self,
        exp: Self::DType,
    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;

    /// Construct a scalar modulo operation.
    fn rem_scalar(
        self,
        rhs: Self::DType,
    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>
    where
        Self::DType: Real;

    /// Construct a scalar subtraction operation.
    fn sub_scalar(
        self,
        rhs: Self::DType,
    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
}
1531
1532impl<T, A, P> NDArrayMathScalar for Array<T, A, P>
1533where
1534    T: Number,
1535    A: Access<T>,
1536    P: ElementwiseScalar<A, T>,
1537{
1538    type Output = AccessOp<P::Op, P>;
1539
1540    fn add_scalar(
1541        self,
1542        rhs: Self::DType,
1543    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1544        self.apply(|platform, left| platform.add_scalar(left, rhs))
1545    }
1546
1547    fn div_scalar(
1548        self,
1549        rhs: Self::DType,
1550    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1551        if rhs == T::ZERO {
1552            Err(Error::unsupported(format!(
1553                "cannot divide {self:?} by {rhs}"
1554            )))
1555        } else {
1556            self.apply(|platform, left| platform.div_scalar(left, rhs))
1557        }
1558    }
1559
1560    fn log_scalar(
1561        self,
1562        base: Self::DType,
1563    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>
1564    where
1565        Self::DType: Float,
1566    {
1567        self.apply(|platform, arg| platform.log_scalar(arg, base))
1568    }
1569
1570    fn mul_scalar(
1571        self,
1572        rhs: Self::DType,
1573    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1574        self.apply(|platform, left| platform.mul_scalar(left, rhs))
1575    }
1576
1577    fn pow_scalar(
1578        self,
1579        exp: Self::DType,
1580    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1581        self.apply(|platform, arg| platform.pow_scalar(arg, exp))
1582    }
1583
1584    fn rem_scalar(
1585        self,
1586        rhs: Self::DType,
1587    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>
1588    where
1589        Self::DType: Real,
1590    {
1591        self.apply(|platform, left| platform.rem_scalar(left, rhs))
1592    }
1593
1594    fn sub_scalar(
1595        self,
1596        rhs: Self::DType,
1597    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1598        self.apply(|platform, left| platform.sub_scalar(left, rhs))
1599    }
1600}
1601
/// Float-specific array methods
///
/// Only available when the array's data type implements [`Float`]; results have data type `u8`.
pub trait NDArrayNumeric: NDArray + Sized
where
    Self::DType: Float,
{
    /// The return type of a float-specific test.
    type Output: Access<u8>;

    /// Test which elements of this array are infinite.
    #[allow(clippy::wrong_self_convention)]
    // This consumes `self` to allow in-place graph construction, consistent with other ndarray ops.
    fn is_inf(self) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;

    /// Test which elements of this array are not-a-number.
    #[allow(clippy::wrong_self_convention)]
    // This consumes `self` to allow in-place graph construction, consistent with other ndarray ops.
    fn is_nan(self) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
}
1619
1620impl<T, A, P> NDArrayNumeric for Array<T, A, P>
1621where
1622    T: Float,
1623    A: Access<T>,
1624    P: ElementwiseNumeric<A, T>,
1625{
1626    type Output = AccessOp<P::Op, P>;
1627
1628    fn is_inf(self) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1629        self.apply(|platform, access| platform.is_inf(access))
1630    }
1631
1632    fn is_nan(self) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1633        self.apply(|platform, access| platform.is_nan(access))
1634    }
1635}
1636
/// Boolean array reduce operations
///
/// Each method consumes the array and reduces it to a single `bool`.
pub trait NDArrayReduceBoolean: NDArrayRead {
    /// Return `true` if this array contains only non-zero elements.
    fn all(self) -> Result<bool, Error>;

    /// Return `true` if this array contains any non-zero elements.
    fn any(self) -> Result<bool, Error>;
}
1645
1646impl<T, A, P> NDArrayReduceBoolean for Array<T, A, P>
1647where
1648    T: Number,
1649    A: Access<T>,
1650    P: ReduceAll<A, T>,
1651{
1652    fn all(self) -> Result<bool, Error> {
1653        self.platform.all(self.access)
1654    }
1655
1656    fn any(self) -> Result<bool, Error> {
1657        self.platform.any(self.access)
1658    }
1659}
1660
/// Array reduce operations
///
/// Each method consumes the array and reduces it to a single value of its data type.
pub trait NDArrayReduceAll: NDArrayRead {
    /// Return the maximum of all elements in this array.
    fn max_all(self) -> Result<Self::DType, Error>
    where
        Self::DType: Real;

    /// Return the minimum of all elements in this array.
    fn min_all(self) -> Result<Self::DType, Error>
    where
        Self::DType: Real;

    /// Return the product of all elements in this array.
    fn product_all(self) -> Result<Self::DType, Error>;

    /// Return the sum of all elements in this array.
    fn sum_all(self) -> Result<Self::DType, Error>;
}
1679
1680impl<T, A, P> NDArrayReduceAll for Array<T, A, P>
1681where
1682    T: Number,
1683    A: Access<T>,
1684    P: ReduceAll<A, T>,
1685{
1686    fn max_all(self) -> Result<Self::DType, Error>
1687    where
1688        T: Real,
1689    {
1690        self.platform.max(self.access)
1691    }
1692
1693    fn min_all(self) -> Result<Self::DType, Error>
1694    where
1695        T: Real,
1696    {
1697        self.platform.min(self.access)
1698    }
1699
1700    fn product_all(self) -> Result<Self::DType, Error> {
1701        self.platform.product(self.access)
1702    }
1703
1704    fn sum_all(self) -> Result<T, Error> {
1705        self.platform.sum(self.access)
1706    }
1707}
1708
1709impl<T, A, P> fmt::Debug for Array<T, A, P> {
1710    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1711        write!(
1712            f,
1713            "a {} array of shape {:?}",
1714            std::any::type_name::<T>(),
1715            self.shape
1716        )
1717    }
1718}
1719
/// Array trigonometry methods
///
/// Each method consumes `self` and returns a new [`Array`] with the same data type.
pub trait NDArrayTrig: NDArray + Sized {
    /// The return type of a trigonometric operation.
    type Output: Access<Self::DType>;

    /// Construct a new sine operation.
    fn sin(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;

    /// Construct a new arcsine operation.
    fn asin(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;

    /// Construct a new hyperbolic sine operation.
    fn sinh(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;

    /// Construct a new cosine operation.
    fn cos(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;

    /// Construct a new arccosine operation.
    fn acos(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;

    /// Construct a new hyperbolic cosine operation.
    fn cosh(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;

    /// Construct a new tangent operation.
    fn tan(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;

    /// Construct a new arctangent operation.
    fn atan(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;

    /// Construct a new hyperbolic tangent operation.
    fn tanh(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
}
1751
1752impl<T, A, P> NDArrayTrig for Array<T, A, P>
1753where
1754    T: Float,
1755    A: Access<T>,
1756    P: ElementwiseTrig<A, T>,
1757{
1758    type Output = AccessOp<P::Op, P>;
1759
1760    fn sin(self) -> Result<Array<T, Self::Output, Self::Platform>, Error> {
1761        self.apply(|platform, access| platform.sin(access))
1762    }
1763
1764    fn asin(self) -> Result<Array<T, Self::Output, Self::Platform>, Error> {
1765        self.apply(|platform, access| platform.asin(access))
1766    }
1767
1768    fn sinh(self) -> Result<Array<T, Self::Output, Self::Platform>, Error> {
1769        self.apply(|platform, access| platform.sinh(access))
1770    }
1771
1772    fn cos(self) -> Result<Array<T, Self::Output, Self::Platform>, Error> {
1773        self.apply(|platform, access| platform.cos(access))
1774    }
1775
1776    fn acos(self) -> Result<Array<T, Self::Output, Self::Platform>, Error> {
1777        self.apply(|platform, access| platform.acos(access))
1778    }
1779
1780    fn cosh(self) -> Result<Array<T, Self::Output, Self::Platform>, Error> {
1781        self.apply(|platform, access| platform.cosh(access))
1782    }
1783
1784    fn tan(self) -> Result<Array<T, Self::Output, Self::Platform>, Error> {
1785        self.apply(|platform, access| platform.tan(access))
1786    }
1787
1788    fn atan(self) -> Result<Array<T, Self::Output, Self::Platform>, Error> {
1789        self.apply(|platform, access| platform.atan(access))
1790    }
1791
1792    fn tanh(self) -> Result<Array<T, Self::Output, Self::Platform>, Error> {
1793        self.apply(|platform, access| platform.tanh(access))
1794    }
1795}
1796
/// Conditional selection (boolean logic) methods
///
/// Implemented by arrays with data type `u8`, which act as the condition.
/// `T` is the data type of the result, `L` the type of the `then` argument,
/// and `R` the type of the `or_else` argument.
pub trait NDArrayWhere<T, L, R>: NDArray<DType = u8> + fmt::Debug
where
    T: Number,
{
    /// The return type of a conditional selection.
    type Output: Access<T>;

    /// Construct a boolean selection operation.
    /// The resulting array will return values from `then` where `self` is `true`
    /// and from `or_else` where `self` is `false`.
    fn cond(self, then: L, or_else: R) -> Result<Array<T, Self::Output, Self::Platform>, Error>;
}
1809
1810impl<T, A, L, R, P> NDArrayWhere<T, Array<T, L, P>, Array<T, R, P>> for Array<u8, A, P>
1811where
1812    T: Number,
1813    A: Access<u8>,
1814    L: Access<T>,
1815    R: Access<T>,
1816    P: GatherCond<A, L, R, T>,
1817{
1818    type Output = AccessOp<P::Op, P>;
1819
1820    fn cond(
1821        self,
1822        then: Array<T, L, P>,
1823        or_else: Array<T, R, P>,
1824    ) -> Result<Array<T, Self::Output, Self::Platform>, Error> {
1825        same_shape("cond", self.shape(), then.shape())?;
1826        same_shape("cond", self.shape(), or_else.shape())?;
1827
1828        let access = self
1829            .platform
1830            .cond(self.access, then.access, or_else.access)?;
1831
1832        Ok(Array {
1833            shape: self.shape,
1834            access,
1835            platform: self.platform,
1836            dtype: PhantomData,
1837        })
1838    }
1839}
1840
/// Matrix dual operations
///
/// `O` is the type of the right-hand array; it must have the same data type as `Self`.
pub trait MatrixDual<O>: NDArray + fmt::Debug
where
    O: NDArray<DType = Self::DType> + fmt::Debug,
{
    /// The return type of a matrix multiplication.
    type Output: Access<Self::DType>;

    /// Construct an operation to multiply this matrix or batch of matrices with the `other`.
    fn matmul(self, other: O) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
}
1851
1852impl<T, L, R, P> MatrixDual<Array<T, R, P>> for Array<T, L, P>
1853where
1854    T: Number,
1855    L: Access<T>,
1856    R: Access<T>,
1857    P: LinAlgDual<L, R, T>,
1858{
1859    type Output = AccessOp<P::Op, P>;
1860
1861    fn matmul(
1862        self,
1863        other: Array<T, R, P>,
1864    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1865        let dims = matmul_dims(&self.shape, &other.shape).ok_or_else(|| {
1866            Error::bounds(format!(
1867                "invalid dimensions for matrix multiply: {:?} and {:?}",
1868                self.shape, other.shape
1869            ))
1870        })?;
1871
1872        let mut shape = Shape::with_capacity(self.ndim());
1873        shape.extend(self.shape.iter().rev().skip(2).rev().copied());
1874        shape.push(dims[1]);
1875        shape.push(dims[3]);
1876
1877        let platform = P::select(dims.iter().product());
1878
1879        let access = platform.matmul(self.access, other.access, dims)?;
1880
1881        Ok(Array {
1882            shape,
1883            access,
1884            platform,
1885            dtype: self.dtype,
1886        })
1887    }
1888}
1889
/// Matrix unary operations
pub trait MatrixUnary: NDArray + fmt::Debug {
    /// The return type of a diagonal read operation.
    type Diag: Access<Self::DType>;
    /// The return type of a matrix transpose operation.
    type Transpose: Access<Self::DType>;

    /// Transpose a matrix or batch of matrices (i.e., transpose the last two dimensions).
    fn mt(self) -> Result<Array<Self::DType, Self::Transpose, Self::Platform>, Error>;

    /// Construct an operation to read the diagonal(s) of this matrix or batch of matrices.
    /// This will return an error if the last two dimensions of the batch are unequal.
    fn diag(self) -> Result<Array<Self::DType, Self::Diag, Self::Platform>, Error>;
}
1902
1903impl<T, A, P> MatrixUnary for Array<T, A, P>
1904where
1905    T: Number,
1906    A: Access<T>,
1907    P: LinAlgUnary<A, T> + Transform<A, T>,
1908{
1909    type Diag = AccessOp<<P as LinAlgUnary<A, T>>::Op, P>;
1910    type Transpose = AccessOp<<P as Transform<A, T>>::Transpose, P>;
1911
1912    fn mt(self) -> Result<Array<Self::DType, Self::Transpose, Self::Platform>, Error> {
1913        let ndim = self.ndim();
1914        let mut permutation = Axes::with_capacity(ndim);
1915        permutation.extend(0..self.ndim() - 2);
1916        permutation.push(ndim - 1);
1917        permutation.push(ndim - 2);
1918        self.transpose(permutation)
1919    }
1920
1921    fn diag(self) -> Result<Array<T, AccessOp<P::Op, P>, P>, Error> {
1922        if self.ndim() >= 2 && self.shape.last() == self.shape.iter().nth_back(1) {
1923            let batch_size = self.shape.iter().rev().skip(2).product();
1924            let dim = self.shape.last().copied().expect("dim");
1925
1926            let shape = self.shape.iter().rev().skip(1).rev().copied().collect();
1927            let platform = P::select(batch_size * dim * dim);
1928            let access = platform.diag(self.access, batch_size, dim)?;
1929
1930            Ok(Array {
1931                shape,
1932                access,
1933                platform,
1934                dtype: PhantomData,
1935            })
1936        } else {
1937            Err(Error::bounds(format!(
1938                "invalid shape for diagonal: {:?}",
1939                self.shape
1940            )))
1941        }
1942    }
1943}
1944
/// Unary operations on a complex matrix, or on each matrix in a batch of complex matrices
#[cfg(feature = "complex")]
pub trait MatrixUnaryComplex: MatrixUnary
where
    Self::DType: Complex,
{
    /// The accessor of the array returned by [`MatrixUnaryComplex::mh`]
    type Hermitian: Access<Self::DType>;

    /// Construct the conjugate transpose of this matrix, or of every matrix in this batch.
    fn mh(self) -> Result<Array<Self::DType, Self::Hermitian, Self::Platform>, Error>;
}
1956
#[cfg(feature = "complex")]
impl<T, A, P> MatrixUnaryComplex for Array<T, A, P>
where
    T: Complex,
    A: Access<T>,
    P: complex::ElementwiseUnaryComplex<Self::Transpose, T> + LinAlgUnary<A, T> + Transform<A, T>,
{
    type Hermitian = AccessOp<P::Complex, P>;

    /// Construct the conjugate transpose by transposing the last two dimensions
    /// and then conjugating the transposed array.
    fn mh(self) -> Result<Array<Self::DType, Self::Hermitian, Self::Platform>, Error> {
        // an error from the transpose is returned as-is, before any conjugation is attempted
        let transposed = self.mt()?;
        transposed.conj()
    }
}
1970
/// Check whether two shapes are compatible for broadcasting.
///
/// The shapes are compared dimension by dimension, starting from the last.
/// Each pair of dimensions must be equal or contain a `1`.
/// Leading dimensions of the longer shape which have no counterpart are not checked.
#[inline]
fn can_broadcast(left: &[usize], right: &[usize]) -> bool {
    // `zip` stops at the end of the shorter shape, so the argument order does not matter
    left.iter()
        .rev()
        .zip(right.iter().rev())
        .all(|(&l, &r)| l == r || l == 1 || r == 1)
}
1987
/// Compute the dimensions `[batch_size, a, b, c]` of a (batched) matrix multiplication
/// of a `left` shape ending in `[a, b]` with a `right` shape ending in `[b, c]`.
///
/// Returns `None` if either shape has fewer than two dimensions, if the inner dimensions
/// differ, or if the leading (batch) dimensions of the two shapes are not identical.
#[inline]
fn matmul_dims(left: &[usize], right: &[usize]) -> Option<[usize; 4]> {
    let (left_batch, a, b) = match left {
        [batch @ .., a, b] => (batch, *a, *b),
        _ => return None,
    };

    // the second-to-last dimension on the right must match the last dimension on the left
    let (right_batch, c) = match right {
        [batch @ .., inner, c] if *inner == b => (batch, *c),
        _ => return None,
    };

    // walk the batch dimensions from last to first, multiplying as long as they agree
    let batch_size = left_batch
        .iter()
        .rev()
        .zip(right_batch.iter().rev())
        .try_fold(1_usize, |size, (&l, &r)| {
            if l == r {
                Some(size * l)
            } else {
                None
            }
        })?;

    // one shape having more batch dimensions than the other is also a mismatch
    if left_batch.len() == right_batch.len() {
        Some([batch_size, a, b, c])
    } else {
        None
    }
}
2014
2015#[inline]
2016fn permute_for_reduce<'a, T, A, P>(
2017    platform: P,
2018    access: A,
2019    shape: Shape,
2020    axes: Axes,
2021) -> Result<Accessor<'a, T>, Error>
2022where
2023    T: Number,
2024    A: Access<T>,
2025    P: Transform<A, T>,
2026    Accessor<'a, T>: From<A> + From<AccessOp<P::Transpose, P>>,
2027{
2028    let mut permutation = Axes::with_capacity(shape.len());
2029    permutation.extend((0..shape.len()).filter(|x| !axes.contains(x)));
2030    permutation.extend(axes);
2031
2032    if permutation.iter().copied().enumerate().all(|(i, x)| i == x) {
2033        Ok(Accessor::from(access))
2034    } else {
2035        platform
2036            .transpose(access, shape, permutation)
2037            .map(Accessor::from)
2038    }
2039}
2040
2041#[inline]
2042fn reduce_axes(shape: &[usize], axes: &[usize], keepdims: bool) -> Result<Shape, Error> {
2043    let mut shape = Shape::from_slice(shape);
2044
2045    for x in axes.iter().copied().rev() {
2046        if x >= shape.len() {
2047            return Err(Error::bounds(format!(
2048                "axis {x} is out of bounds for {shape:?}"
2049            )));
2050        } else if keepdims {
2051            shape[x] = 1;
2052        } else {
2053            shape.remove(x);
2054        }
2055    }
2056
2057    if shape.is_empty() {
2058        Ok(shape![1])
2059    } else {
2060        Ok(shape)
2061    }
2062}
2063
2064#[inline]
2065pub fn same_shape(op_name: &'static str, left: &[usize], right: &[usize]) -> Result<(), Error> {
2066    if left == right {
2067        Ok(())
2068    } else if can_broadcast(left, right) {
2069        Err(Error::bounds(format!(
2070            "cannot {op_name} arrays with shapes {left:?} and {right:?} (consider broadcasting)"
2071        )))
2072    } else {
2073        Err(Error::bounds(format!(
2074            "cannot {op_name} arrays with shapes {left:?} and {right:?}"
2075        )))
2076    }
2077}
2078
2079#[inline]
2080fn valid_coord(coord: &[usize], shape: &[usize]) -> Result<(), Error> {
2081    if coord.len() == shape.len() && coord.iter().zip(shape).all(|(i, dim)| i < dim) {
2082        return Ok(());
2083    }
2084
2085    Err(Error::bounds(format!(
2086        "invalid coordinate {coord:?} for shape {shape:?}"
2087    )))
2088}