ha_ndarray/
array.rs

1use std::borrow::{Borrow, BorrowMut};
2use std::fmt;
3use std::marker::PhantomData;
4
5use crate::access::*;
6use crate::buffer::BufferInstance;
7use crate::ops::*;
8use crate::platform::PlatformInstance;
9use crate::{
10    range_shape, shape, strides_for, Axes, AxisRange, BufferConverter, CType, Constant, Convert,
11    Error, Float, Platform, Range, Shape,
12};
13
14pub struct Array<T, A, P> {
15    shape: Shape,
16    access: A,
17    platform: P,
18    dtype: PhantomData<T>,
19}
20
21impl<T, A: Clone, P: Clone> Clone for Array<T, A, P> {
22    fn clone(&self) -> Self {
23        Self {
24            shape: self.shape.clone(),
25            access: self.access.clone(),
26            platform: self.platform.clone(),
27            dtype: self.dtype,
28        }
29    }
30}
31
32impl<T, A, P> Array<T, A, P> {
33    fn apply<O, OT, Op>(self, op: Op) -> Result<Array<OT, AccessOp<O, P>, P>, Error>
34    where
35        P: Copy,
36        Op: Fn(P, A) -> Result<AccessOp<O, P>, Error>,
37    {
38        let access = (op)(self.platform, self.access)?;
39
40        Ok(Array {
41            shape: self.shape,
42            access,
43            platform: self.platform,
44            dtype: PhantomData,
45        })
46    }
47
48    fn reduce_axes<Op>(
49        self,
50        mut axes: Axes,
51        keepdims: bool,
52        op: Op,
53    ) -> Result<Array<T, AccessOp<P::Op, P>, P>, Error>
54    where
55        T: CType,
56        A: Access<T>,
57        P: Transform<A, T> + ReduceAxes<Accessor<T>, T>,
58        Op: Fn(P, Accessor<T>, usize) -> Result<AccessOp<P::Op, P>, Error>,
59        Accessor<T>: From<A> + From<AccessOp<P::Transpose, P>>,
60    {
61        axes.sort();
62        axes.dedup();
63
64        let shape = reduce_axes(&self.shape, &axes, keepdims)?;
65        let size = shape.iter().product::<usize>();
66        let stride = axes.iter().copied().map(|x| self.shape[x]).product();
67        let platform = P::select(size);
68
69        let access = permute_for_reduce(self.platform, self.access, self.shape, axes)?;
70        let access = (op)(self.platform, access, stride)?;
71
72        Ok(Array {
73            access,
74            shape,
75            platform,
76            dtype: PhantomData,
77        })
78    }
79
80    pub fn access(&self) -> &A {
81        &self.access
82    }
83
84    pub fn into_access(self) -> A {
85        self.access
86    }
87}
88
89impl<T, L, P> Array<T, L, P> {
90    fn apply_dual<O, OT, R, Op>(
91        self,
92        other: Array<T, R, P>,
93        op: Op,
94    ) -> Result<Array<OT, AccessOp<O, P>, P>, Error>
95    where
96        P: Copy,
97        Op: Fn(P, L, R) -> Result<AccessOp<O, P>, Error>,
98    {
99        let access = (op)(self.platform, self.access, other.access)?;
100
101        Ok(Array {
102            shape: self.shape,
103            access,
104            platform: self.platform,
105            dtype: PhantomData,
106        })
107    }
108}
109
110// constructors
111impl<T: CType> Array<T, Accessor<T>, Platform> {
112    pub fn from<A, P>(array: Array<T, A, P>) -> Self
113    where
114        Accessor<T>: From<A>,
115        Platform: From<P>,
116    {
117        Self {
118            shape: array.shape,
119            access: array.access.into(),
120            platform: array.platform.into(),
121            dtype: array.dtype,
122        }
123    }
124}
125
126impl<T, B, P> Array<T, AccessBuf<B>, P>
127where
128    T: CType,
129    B: BufferInstance<T>,
130    P: PlatformInstance,
131{
132    fn new_inner(platform: P, buffer: B, shape: Shape) -> Result<Self, Error> {
133        if !shape.is_empty() && shape.iter().product::<usize>() == buffer.len() {
134            let access = buffer.into();
135
136            Ok(Self {
137                shape,
138                access,
139                platform,
140                dtype: PhantomData,
141            })
142        } else {
143            Err(Error::Bounds(format!(
144                "cannot construct an array with shape {shape:?} from a buffer of size {}",
145                buffer.len(),
146            )))
147        }
148    }
149
150    pub fn convert<'a, FB>(buffer: FB, shape: Shape) -> Result<Self, Error>
151    where
152        FB: Into<BufferConverter<'a, T>>,
153        P: Convert<T, Buffer = B>,
154    {
155        let buffer = buffer.into();
156        let platform = P::select(buffer.len());
157        let buffer = platform.convert(buffer)?;
158        Self::new_inner(platform, buffer, shape)
159    }
160
161    pub fn new(buffer: B, shape: Shape) -> Result<Self, Error> {
162        let platform = P::select(buffer.len());
163        Self::new_inner(platform, buffer, shape)
164    }
165}
166
167impl<T, P> Array<T, AccessBuf<P::Buffer>, P>
168where
169    T: CType,
170    P: Constant<T>,
171{
172    pub fn constant(value: T, shape: Shape) -> Result<Self, Error> {
173        if !shape.is_empty() {
174            let size = shape.iter().product();
175            let platform = P::select(size);
176            let buffer = platform.constant(value, size)?;
177            let access = buffer.into();
178
179            Ok(Self {
180                shape,
181                access,
182                platform,
183                dtype: PhantomData,
184            })
185        } else {
186            Err(Error::Bounds(
187                "cannot construct an array with an empty shape".to_string(),
188            ))
189        }
190    }
191}
192
193impl<T, P> Array<T, AccessBuf<P::Buffer>, P>
194where
195    T: CType,
196    P: Convert<T>,
197{
198    pub fn copy<A: Access<T>>(source: &Array<T, A, P>) -> Result<Self, Error> {
199        let buffer = source
200            .buffer()
201            .and_then(|buf| source.platform.convert(buf))?;
202
203        Ok(Self {
204            shape: source.shape.clone(),
205            access: buffer.into(),
206            platform: source.platform,
207            dtype: source.dtype,
208        })
209    }
210}
211
212// op constructors
213impl<T: CType, P: PlatformInstance> Array<T, AccessOp<P::Range, P>, P>
214where
215    P: Construct<T>,
216{
217    pub fn range(start: T, stop: T, shape: Shape) -> Result<Self, Error> {
218        let size = shape.iter().product();
219        let platform = P::select(size);
220
221        platform.range(start, stop, size).map(|access| Self {
222            shape,
223            access,
224            platform,
225            dtype: PhantomData,
226        })
227    }
228}
229
230impl<P: PlatformInstance> Array<f32, AccessOp<P::Normal, P>, P>
231where
232    P: Random,
233{
234    pub fn random_normal(size: usize) -> Result<Self, Error> {
235        let platform = P::select(size);
236        let shape = shape![size];
237
238        platform.random_normal(size).map(|access| Self {
239            shape,
240            access,
241            platform,
242            dtype: PhantomData,
243        })
244    }
245}
246
247impl<P: PlatformInstance> Array<f32, AccessOp<P::Uniform, P>, P>
248where
249    P: Random,
250{
251    pub fn random_uniform(size: usize) -> Result<Self, Error> {
252        let platform = P::select(size);
253        let shape = shape![size];
254
255        platform.random_uniform(size).map(|access| Self {
256            shape,
257            access,
258            platform,
259            dtype: PhantomData,
260        })
261    }
262}
263
264// references
265impl<T, B, P> Array<T, AccessBuf<B>, P>
266where
267    T: CType,
268    B: BufferInstance<T>,
269    P: PlatformInstance,
270{
271    pub fn as_mut<RB: ?Sized>(&mut self) -> Array<T, AccessBuf<&mut RB>, P>
272    where
273        B: BorrowMut<RB>,
274    {
275        Array {
276            shape: Shape::from_slice(&self.shape),
277            access: self.access.as_mut(),
278            platform: self.platform,
279            dtype: PhantomData,
280        }
281    }
282
283    pub fn as_ref<RB: ?Sized>(&self) -> Array<T, AccessBuf<&RB>, P>
284    where
285        B: Borrow<RB>,
286    {
287        Array {
288            shape: Shape::from_slice(&self.shape),
289            access: self.access.as_ref(),
290            platform: self.platform,
291            dtype: PhantomData,
292        }
293    }
294}
295
296impl<T, O, P> Array<T, AccessOp<O, P>, P>
297where
298    T: CType,
299    O: Enqueue<P, T>,
300    P: PlatformInstance,
301{
302    pub fn as_mut<'a>(&'a mut self) -> Array<T, &'a mut AccessOp<O, P>, P>
303    where
304        O: Write<P, T>,
305    {
306        Array {
307            shape: Shape::from_slice(&self.shape),
308            access: &mut self.access,
309            platform: self.platform,
310            dtype: PhantomData,
311        }
312    }
313
314    pub fn as_ref(&self) -> Array<T, &AccessOp<O, P>, P> {
315        Array {
316            shape: Shape::from_slice(&self.shape),
317            access: &self.access,
318            platform: self.platform,
319            dtype: PhantomData,
320        }
321    }
322}
323
324// traits
325
326/// An n-dimensional array
327pub trait NDArray: Send + Sync {
328    /// The data type of the elements in this array
329    type DType: CType;
330
331    /// The platform used to construct operations on this array.
332    type Platform: PlatformInstance;
333
334    /// Return the number of dimensions in this array.
335    fn ndim(&self) -> usize {
336        self.shape().len()
337    }
338
339    /// Return the number of elements in this array.
340    fn size(&self) -> usize {
341        self.shape().iter().product()
342    }
343
344    /// Borrow the shape of this array.
345    fn shape(&self) -> &[usize];
346}
347
348impl<T, A, P> NDArray for Array<T, A, P>
349where
350    T: CType,
351    A: Access<T>,
352    P: PlatformInstance,
353{
354    type DType = T;
355    type Platform = P;
356
357    fn shape(&self) -> &[usize] {
358        &self.shape
359    }
360}
361
362/// Access methods for an [`NDArray`]
363pub trait NDArrayRead: NDArray + fmt::Debug + Sized {
364    /// Read the value of this [`NDArray`] into a [`BufferConverter`].
365    fn buffer(&self) -> Result<BufferConverter<Self::DType>, Error>;
366
367    /// Buffer this [`NDArray`] into a new, owned array, allocating only if needed.
368    fn into_read(
369        self,
370    ) -> Result<
371        Array<
372            Self::DType,
373            AccessBuf<<Self::Platform as Convert<Self::DType>>::Buffer>,
374            Self::Platform,
375        >,
376        Error,
377    >
378    where
379        Self::Platform: Convert<Self::DType>;
380
381    /// Read the value at a specific `coord` in this [`NDArray`].
382    fn read_value(&self, coord: &[usize]) -> Result<Self::DType, Error>;
383}
384
385impl<T, A, P> NDArrayRead for Array<T, A, P>
386where
387    T: CType,
388    A: Access<T>,
389    P: PlatformInstance,
390{
391    fn buffer(&self) -> Result<BufferConverter<T>, Error> {
392        self.access.read()
393    }
394
395    fn into_read(self) -> Result<Array<Self::DType, AccessBuf<P::Buffer>, Self::Platform>, Error>
396    where
397        P: Convert<T>,
398    {
399        let buffer = self.buffer().and_then(|buf| self.platform.convert(buf))?;
400        debug_assert_eq!(buffer.len(), self.size());
401
402        Ok(Array {
403            shape: self.shape,
404            access: buffer.into(),
405            platform: self.platform,
406            dtype: self.dtype,
407        })
408    }
409
410    fn read_value(&self, coord: &[usize]) -> Result<T, Error> {
411        valid_coord(coord, self.shape())?;
412
413        let strides = strides_for(self.shape(), self.ndim());
414
415        let offset = coord
416            .iter()
417            .zip(strides)
418            .map(|(i, stride)| i * stride)
419            .sum();
420
421        self.access.read_value(offset)
422    }
423}
424
425/// Access methods for a mutable [`NDArray`]
426pub trait NDArrayWrite: NDArray + fmt::Debug + Sized {
427    /// Overwrite this [`NDArray`] with the value of the `other` array.
428    fn write<O: NDArrayRead<DType = Self::DType>>(&mut self, other: &O) -> Result<(), Error>;
429
430    /// Overwrite this [`NDArray`] with a constant scalar `value`.
431    fn write_value(&mut self, value: Self::DType) -> Result<(), Error>;
432
433    /// Write the given `value` at the given `coord` of this [`NDArray`].
434    fn write_value_at(&mut self, coord: &[usize], value: Self::DType) -> Result<(), Error>;
435}
436
437// write ops
438impl<T, A, P> NDArrayWrite for Array<T, A, P>
439where
440    T: CType,
441    A: AccessMut<T>,
442    P: PlatformInstance,
443{
444    fn write<O>(&mut self, other: &O) -> Result<(), Error>
445    where
446        O: NDArrayRead<DType = Self::DType>,
447    {
448        same_shape("write", self.shape(), other.shape())?;
449        other.buffer().and_then(|buf| self.access.write(buf))
450    }
451
452    fn write_value(&mut self, value: Self::DType) -> Result<(), Error> {
453        self.access.write_value(value)
454    }
455
456    fn write_value_at(&mut self, coord: &[usize], value: Self::DType) -> Result<(), Error> {
457        valid_coord(coord, self.shape())?;
458
459        let offset = coord
460            .iter()
461            .zip(strides_for(self.shape(), self.ndim()))
462            .map(|(i, stride)| i * stride)
463            .sum();
464
465        self.access.write_value_at(offset, value)
466    }
467}
468
469// op traits
470
471/// Array cast operations
472pub trait NDArrayCast<OT: CType>: NDArray + Sized {
473    type Output: Access<OT>;
474
475    /// Construct a new array cast operation.
476    fn cast(self) -> Result<Array<OT, Self::Output, Self::Platform>, Error>;
477}
478
479impl<IT, OT, A, P> NDArrayCast<OT> for Array<IT, A, P>
480where
481    IT: CType,
482    OT: CType,
483    A: Access<IT>,
484    P: ElementwiseCast<A, IT, OT>,
485{
486    type Output = AccessOp<P::Op, P>;
487
488    fn cast(self) -> Result<Array<OT, AccessOp<P::Op, P>, P>, Error> {
489        Ok(Array {
490            shape: self.shape,
491            access: self.platform.cast(self.access)?,
492            platform: self.platform,
493            dtype: PhantomData,
494        })
495    }
496}
497
498/// Axis-wise array reduce operations
499pub trait NDArrayReduce: NDArray + fmt::Debug {
500    type Output: Access<Self::DType>;
501
502    /// Construct a max-reduce operation over the given `axes`.
503    fn max(
504        self,
505        axes: Axes,
506        keepdims: bool,
507    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
508
509    /// Construct a min-reduce operation over the given `axes`.
510    fn min(
511        self,
512        axes: Axes,
513        keepdims: bool,
514    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
515
516    /// Construct a product-reduce operation over the given `axes`.
517    fn product(
518        self,
519        axes: Axes,
520        keepdims: bool,
521    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
522
523    /// Construct a sum-reduce operation over the given `axes`.
524    fn sum(
525        self,
526        axes: Axes,
527        keepdims: bool,
528    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
529}
530
531impl<T, A, P> NDArrayReduce for Array<T, A, P>
532where
533    T: CType,
534    A: Access<T>,
535    P: Transform<A, T> + ReduceAxes<Accessor<T>, T>,
536    Accessor<T>: From<A> + From<AccessOp<P::Transpose, P>>,
537{
538    type Output = AccessOp<P::Op, P>;
539
540    fn max(
541        self,
542        axes: Axes,
543        keepdims: bool,
544    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
545        self.reduce_axes(axes, keepdims, |platform, access, stride| {
546            ReduceAxes::max(platform, access, stride)
547        })
548    }
549
550    fn min(
551        self,
552        axes: Axes,
553        keepdims: bool,
554    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
555        self.reduce_axes(axes, keepdims, |platform, access, stride| {
556            ReduceAxes::min(platform, access, stride)
557        })
558    }
559
560    fn product(
561        self,
562        axes: Axes,
563        keepdims: bool,
564    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
565        self.reduce_axes(axes, keepdims, |platform, access, stride| {
566            ReduceAxes::product(platform, access, stride)
567        })
568    }
569
570    fn sum(
571        self,
572        axes: Axes,
573        keepdims: bool,
574    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
575        self.reduce_axes(axes, keepdims, |platform, access, stride| {
576            ReduceAxes::sum(platform, access, stride)
577        })
578    }
579}
580
581/// Array transform operations
582pub trait NDArrayTransform: NDArray + Sized + fmt::Debug {
583    /// The type returned by `broadcast`
584    type Broadcast: Access<Self::DType>;
585
586    /// The type returned by `slice`
587    type Slice: Access<Self::DType>;
588
589    /// The type returned by `transpose`
590    type Transpose: Access<Self::DType>;
591
592    /// Broadcast this array into the given `shape`.
593    fn broadcast(
594        self,
595        shape: Shape,
596    ) -> Result<Array<Self::DType, Self::Broadcast, Self::Platform>, Error>;
597
598    /// Reshape this `array`.
599    fn reshape(self, shape: Shape) -> Result<Self, Error>;
600
601    /// Construct a slice of this array.
602    fn slice(self, range: Range) -> Result<Array<Self::DType, Self::Slice, Self::Platform>, Error>;
603
604    /// Contract the given `axes` of this array.
605    /// This will return an error if any of the `axes` have dimension > 1.
606    fn squeeze(self, axes: Axes) -> Result<Self, Error>;
607
608    /// Expand the given `axes` of this array.
609    fn unsqueeze(self, axes: Axes) -> Result<Self, Error>;
610
611    /// Transpose this array according to the given `permutation`.
612    /// If no permutation is given, the array axes will be reversed.
613    fn transpose(
614        self,
615        permutation: Option<Axes>,
616    ) -> Result<Array<Self::DType, Self::Transpose, Self::Platform>, Error>;
617}
618
619impl<T, A, P> NDArrayTransform for Array<T, A, P>
620where
621    T: CType,
622    A: Access<T>,
623    P: Transform<A, T>,
624{
625    type Broadcast = AccessOp<P::Broadcast, P>;
626    type Slice = AccessOp<P::Slice, P>;
627    type Transpose = AccessOp<P::Transpose, P>;
628
629    fn broadcast(self, shape: Shape) -> Result<Array<T, AccessOp<P::Broadcast, P>, P>, Error> {
630        if !can_broadcast(self.shape(), &shape) {
631            return Err(Error::Bounds(format!(
632                "cannot broadcast {self:?} into {shape:?}"
633            )));
634        }
635
636        let platform = P::select(shape.iter().product());
637        let broadcast = Shape::from_slice(&shape);
638        let access = platform.broadcast(self.access, self.shape, broadcast)?;
639
640        Ok(Array {
641            shape,
642            access,
643            platform,
644            dtype: self.dtype,
645        })
646    }
647
648    fn reshape(mut self, shape: Shape) -> Result<Self, Error> {
649        if shape.iter().product::<usize>() == self.size() {
650            self.shape = shape;
651            Ok(self)
652        } else {
653            Err(Error::Bounds(format!(
654                "cannot reshape an array with shape {:?} into {shape:?}",
655                self.shape
656            )))
657        }
658    }
659
660    fn slice(self, mut range: Range) -> Result<Array<T, AccessOp<P::Slice, P>, P>, Error> {
661        for (dim, range) in self.shape.iter().zip(&range) {
662            match range {
663                AxisRange::At(i) if i < dim => Ok(()),
664                AxisRange::In(start, stop, _step) if start < dim && stop <= dim => Ok(()),
665                AxisRange::Of(indices) if indices.iter().all(|i| i < dim) => Ok(()),
666                range => Err(Error::Bounds(format!(
667                    "invalid range {range:?} for dimension {dim}"
668                ))),
669            }?;
670        }
671
672        for dim in self.shape.iter().skip(range.len()).copied() {
673            range.push(AxisRange::In(0, dim, 1));
674        }
675
676        let shape = range_shape(self.shape(), &range);
677        let access = self.platform.slice(self.access, &self.shape, range)?;
678        let platform = P::select(shape.iter().product());
679
680        Ok(Array {
681            shape,
682            access,
683            platform,
684            dtype: self.dtype,
685        })
686    }
687
688    fn squeeze(mut self, mut axes: Axes) -> Result<Self, Error> {
689        if axes.iter().copied().any(|x| x >= self.ndim()) {
690            return Err(Error::Bounds(format!("invalid contraction axes: {axes:?}")));
691        }
692
693        axes.sort();
694
695        for x in axes.into_iter().rev() {
696            self.shape.remove(x);
697        }
698
699        Ok(self)
700    }
701
702    fn unsqueeze(mut self, mut axes: Axes) -> Result<Self, Error> {
703        if axes.iter().copied().any(|x| x > self.ndim()) {
704            return Err(Error::Bounds(format!("invalid expansion axes: {axes:?}")));
705        }
706
707        axes.sort();
708
709        for x in axes.into_iter().rev() {
710            self.shape.insert(x, 1);
711        }
712
713        Ok(self)
714    }
715
716    fn transpose(
717        self,
718        permutation: Option<Axes>,
719    ) -> Result<Array<T, AccessOp<P::Transpose, P>, P>, Error> {
720        let permutation = if let Some(axes) = permutation {
721            if axes.len() == self.ndim()
722                && axes.iter().copied().all(|x| x < self.ndim())
723                && !(1..axes.len())
724                    .into_iter()
725                    .any(|i| axes[i..].contains(&axes[i - 1]))
726            {
727                Ok(axes)
728            } else {
729                Err(Error::Bounds(format!(
730                    "invalid permutation for shape {:?}: {:?}",
731                    self.shape, axes
732                )))
733            }
734        } else {
735            Ok((0..self.ndim()).into_iter().rev().collect())
736        }?;
737
738        let shape = permutation.iter().copied().map(|x| self.shape[x]).collect();
739        let platform = self.platform;
740        let access = platform.transpose(self.access, self.shape, permutation)?;
741
742        Ok(Array {
743            shape,
744            access,
745            platform,
746            dtype: self.dtype,
747        })
748    }
749}
750
751/// Unary array operations
752pub trait NDArrayUnary: NDArray + Sized {
753    /// The return type of a unary operation.
754    type Output: Access<Self::DType>;
755
756    /// Construct an absolute value operation.
757    fn abs(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
758
759    /// Construct an exponentiation operation.
760    fn exp(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
761
762    /// Construct a natural logarithm operation.
763    fn ln(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
764
765    /// Construct an integer rounding operation.
766    fn round(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
767}
768
769impl<T, A, P> NDArrayUnary for Array<T, A, P>
770where
771    T: CType,
772    A: Access<T>,
773    P: ElementwiseUnary<A, T>,
774{
775    type Output = AccessOp<P::Op, P>;
776
777    fn abs(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
778        self.apply(|platform, access| platform.abs(access))
779    }
780
781    fn exp(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
782        self.apply(|platform, access| platform.exp(access))
783    }
784
785    fn ln(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>
786    where
787        P: ElementwiseUnary<A, T>,
788    {
789        self.apply(|platform, access| platform.ln(access))
790    }
791
792    fn round(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
793        self.apply(|platform, access| platform.round(access))
794    }
795}
796
797/// Unary boolean array operations
798pub trait NDArrayUnaryBoolean: NDArray + Sized {
799    /// The return type of a unary operation.
800    type Output: Access<u8>;
801
802    /// Construct a boolean not operation.
803    fn not(self) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
804}
805
806impl<T, A, P> NDArrayUnaryBoolean for Array<T, A, P>
807where
808    T: CType,
809    A: Access<T>,
810    P: ElementwiseUnaryBoolean<A, T>,
811{
812    type Output = AccessOp<P::Op, P>;
813
814    fn not(self) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
815        self.apply(|platform, access| platform.not(access))
816    }
817}
818
819/// Boolean array operations
820pub trait NDArrayBoolean<O>: NDArray + Sized
821where
822    O: NDArray<DType = Self::DType>,
823{
824    type Output: Access<u8>;
825
826    /// Construct a boolean and comparison with the `other` array.
827    fn and(self, other: O) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
828
829    /// Construct a boolean or comparison with the `other` array.
830    fn or(self, other: O) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
831
832    /// Construct a boolean xor comparison with the `other` array.
833    fn xor(self, other: O) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
834}
835
836impl<T, L, R, P> NDArrayBoolean<Array<T, R, P>> for Array<T, L, P>
837where
838    T: CType,
839    L: Access<T>,
840    R: Access<T>,
841    P: ElementwiseBoolean<L, R, T>,
842{
843    type Output = AccessOp<P::Op, P>;
844
845    fn and(self, other: Array<T, R, P>) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
846        same_shape("and", self.shape(), other.shape())?;
847        self.apply_dual(other, |platform, left, right| platform.and(left, right))
848    }
849
850    fn or(self, other: Array<T, R, P>) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
851        same_shape("or", self.shape(), other.shape())?;
852        self.apply_dual(other, |platform, left, right| platform.or(left, right))
853    }
854
855    fn xor(self, other: Array<T, R, P>) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
856        same_shape("xor", self.shape(), other.shape())?;
857        self.apply_dual(other, |platform, left, right| platform.xor(left, right))
858    }
859}
860
861/// Boolean array operations with a scalar argument
862pub trait NDArrayBooleanScalar: NDArray + Sized {
863    type Output: Access<u8>;
864
865    /// Construct a boolean and operation with the `other` value.
866    fn and_scalar(
867        self,
868        other: Self::DType,
869    ) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
870
871    /// Construct a boolean or operation with the `other` value.
872    fn or_scalar(
873        self,
874        other: Self::DType,
875    ) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
876
877    /// Construct a boolean xor operation with the `other` value.
878    fn xor_scalar(
879        self,
880        other: Self::DType,
881    ) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
882}
883
884impl<T, A, P> NDArrayBooleanScalar for Array<T, A, P>
885where
886    T: CType,
887    A: Access<T>,
888    P: ElementwiseBooleanScalar<A, T>,
889{
890    type Output = AccessOp<P::Op, P>;
891
892    fn and_scalar(
893        self,
894        other: Self::DType,
895    ) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
896        self.apply(|platform, access| platform.and_scalar(access, other))
897    }
898
899    fn or_scalar(
900        self,
901        other: Self::DType,
902    ) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
903        self.apply(|platform, access| platform.or_scalar(access, other))
904    }
905
906    fn xor_scalar(
907        self,
908        other: Self::DType,
909    ) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
910        self.apply(|platform, access| platform.xor_scalar(access, other))
911    }
912}
913
914/// Array comparison operations
915pub trait NDArrayCompare<O: NDArray<DType = Self::DType>>: NDArray + Sized {
916    type Output: Access<u8>;
917
918    /// Elementwise equality comparison
919    fn eq(self, other: O) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
920
921    /// Elementwise greater-than-or-equal comparison
922    fn ge(self, other: O) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
923
924    /// Elementwise greater-than comparison
925    fn gt(self, other: O) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
926
927    /// Elementwise less-than-or-equal comparison
928    fn le(self, other: O) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
929
930    /// Elementwise less-than comparison
931    fn lt(self, other: O) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
932
933    /// Elementwise not-equal comparison
934    fn ne(self, other: O) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
935}
936
937impl<T, L, R, P> NDArrayCompare<Array<T, R, P>> for Array<T, L, P>
938where
939    T: CType,
940    L: Access<T>,
941    R: Access<T>,
942    P: ElementwiseCompare<L, R, T>,
943{
944    type Output = AccessOp<P::Op, P>;
945
946    fn eq(self, other: Array<T, R, P>) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
947        same_shape("compare", self.shape(), other.shape())?;
948        self.apply_dual(other, |platform, left, right| platform.eq(left, right))
949    }
950
951    fn ge(self, other: Array<T, R, P>) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
952        same_shape("compare", self.shape(), other.shape())?;
953        self.apply_dual(other, |platform, left, right| platform.ge(left, right))
954    }
955
956    fn gt(self, other: Array<T, R, P>) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
957        same_shape("compare", self.shape(), other.shape())?;
958        self.apply_dual(other, |platform, left, right| platform.gt(left, right))
959    }
960
961    fn le(self, other: Array<T, R, P>) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
962        same_shape("compare", self.shape(), other.shape())?;
963        self.apply_dual(other, |platform, left, right| platform.le(left, right))
964    }
965
966    fn lt(self, other: Array<T, R, P>) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
967        same_shape("compare", self.shape(), other.shape())?;
968        self.apply_dual(other, |platform, left, right| platform.lt(left, right))
969    }
970
971    fn ne(self, other: Array<T, R, P>) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
972        same_shape("compare", self.shape(), other.shape())?;
973        self.apply_dual(other, |platform, left, right| platform.ne(left, right))
974    }
975}
976
977/// Array-scalar comparison operations
978pub trait NDArrayCompareScalar: NDArray + Sized {
979    type Output: Access<u8>;
980
981    /// Construct an equality comparison with the `other` value.
982    fn eq_scalar(
983        self,
984        other: Self::DType,
985    ) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
986
987    /// Construct a greater-than comparison with the `other` value.
988    fn gt_scalar(
989        self,
990        other: Self::DType,
991    ) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
992
993    /// Construct an equal-or-greater-than comparison with the `other` value.
994    fn ge_scalar(
995        self,
996        other: Self::DType,
997    ) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
998
999    /// Construct a less-than comparison with the `other` value.
1000    fn lt_scalar(
1001        self,
1002        other: Self::DType,
1003    ) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
1004
1005    /// Construct an equal-or-less-than comparison with the `other` value.
1006    fn le_scalar(
1007        self,
1008        other: Self::DType,
1009    ) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
1010
1011    /// Construct an not-equal comparison with the `other` value.
1012    fn ne_scalar(
1013        self,
1014        other: Self::DType,
1015    ) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
1016}
1017
1018impl<T, A, P> NDArrayCompareScalar for Array<T, A, P>
1019where
1020    T: CType,
1021    A: Access<T>,
1022    P: ElementwiseScalarCompare<A, T>,
1023{
1024    type Output = AccessOp<P::Op, P>;
1025
1026    fn eq_scalar(
1027        self,
1028        other: Self::DType,
1029    ) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1030        self.apply(|platform, access| platform.eq_scalar(access, other))
1031    }
1032
1033    fn gt_scalar(
1034        self,
1035        other: Self::DType,
1036    ) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1037        self.apply(|platform, access| platform.gt_scalar(access, other))
1038    }
1039
1040    fn ge_scalar(
1041        self,
1042        other: Self::DType,
1043    ) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1044        self.apply(|platform, access| platform.ge_scalar(access, other))
1045    }
1046
1047    fn lt_scalar(
1048        self,
1049        other: Self::DType,
1050    ) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1051        self.apply(|platform, access| platform.lt_scalar(access, other))
1052    }
1053
1054    fn le_scalar(
1055        self,
1056        other: Self::DType,
1057    ) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1058        self.apply(|platform, access| platform.le_scalar(access, other))
1059    }
1060
1061    fn ne_scalar(
1062        self,
1063        other: Self::DType,
1064    ) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1065        self.apply(|platform, access| platform.ne_scalar(access, other))
1066    }
1067}
1068
1069/// Array arithmetic operations
1070pub trait NDArrayMath<O: NDArray<DType = Self::DType>>: NDArray + Sized {
1071    type Output: Access<Self::DType>;
1072
1073    /// Construct an addition operation with the given `rhs`.
1074    fn add(self, rhs: O) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
1075
1076    /// Construct a division operation with the given `rhs`.
1077    fn div(self, rhs: O) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
1078
1079    /// Construct a logarithm operation with the given `base`.
1080    fn log(self, base: O) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
1081
1082    /// Construct a multiplication operation with the given `rhs`.
1083    fn mul(self, rhs: O) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
1084
1085    /// Construct an operation to raise these data to the power of the given `exp`onent.
1086    fn pow(self, exp: O) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
1087
1088    /// Construct an array subtraction operation with the given `rhs`.
1089    fn sub(self, rhs: O) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
1090
1091    /// Construct a modulo operation with the given `rhs`.
1092    fn rem(self, rhs: O) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
1093}
1094
1095impl<T, L, R, P> NDArrayMath<Array<T, R, P>> for Array<T, L, P>
1096where
1097    T: CType,
1098    L: Access<T>,
1099    R: Access<T>,
1100    P: ElementwiseDual<L, R, T>,
1101{
1102    type Output = AccessOp<P::Op, P>;
1103
1104    fn add(
1105        self,
1106        rhs: Array<T, R, P>,
1107    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1108        same_shape("add", self.shape(), rhs.shape())?;
1109        self.apply_dual(rhs, |platform, left, right| platform.add(left, right))
1110    }
1111
1112    fn div(
1113        self,
1114        rhs: Array<T, R, P>,
1115    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1116        same_shape("div", self.shape(), rhs.shape())?;
1117        self.apply_dual(rhs, |platform, left, right| platform.div(left, right))
1118    }
1119
1120    fn log(
1121        self,
1122        base: Array<T, R, P>,
1123    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1124        same_shape("log", self.shape(), base.shape())?;
1125        self.apply_dual(base, |platform, left, right| platform.log(left, right))
1126    }
1127
1128    fn mul(
1129        self,
1130        rhs: Array<T, R, P>,
1131    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1132        same_shape("mul", self.shape(), rhs.shape())?;
1133        self.apply_dual(rhs, |platform, left, right| platform.mul(left, right))
1134    }
1135
1136    fn pow(
1137        self,
1138        exp: Array<T, R, P>,
1139    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1140        same_shape("pow", self.shape(), exp.shape())?;
1141        self.apply_dual(exp, |platform, left, right| platform.pow(left, right))
1142    }
1143
1144    fn sub(
1145        self,
1146        rhs: Array<T, R, P>,
1147    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1148        same_shape("sub", self.shape(), rhs.shape())?;
1149        self.apply_dual(rhs, |platform, left, right| platform.sub(left, right))
1150    }
1151
1152    fn rem(
1153        self,
1154        rhs: Array<T, R, P>,
1155    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1156        same_shape("rem", self.shape(), rhs.shape())?;
1157        self.apply_dual(rhs, |platform, left, right| platform.rem(left, right))
1158    }
1159}
1160
1161/// Array arithmetic operations with a scalar argument
1162pub trait NDArrayMathScalar: NDArray + Sized {
1163    type Output: Access<Self::DType>;
1164
1165    /// Construct a scalar addition operation.
1166    fn add_scalar(
1167        self,
1168        rhs: Self::DType,
1169    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
1170
1171    /// Construct a scalar division operation.
1172    fn div_scalar(
1173        self,
1174        rhs: Self::DType,
1175    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
1176
1177    /// Construct a scalar logarithm operation.
1178    fn log_scalar(
1179        self,
1180        base: Self::DType,
1181    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
1182
1183    /// Construct a scalar multiplication operation.
1184    fn mul_scalar(
1185        self,
1186        rhs: Self::DType,
1187    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
1188
1189    /// Construct a scalar exponentiation operation.
1190    fn pow_scalar(
1191        self,
1192        exp: Self::DType,
1193    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
1194
1195    /// Construct a scalar modulo operation.
1196    fn rem_scalar(
1197        self,
1198        rhs: Self::DType,
1199    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
1200
1201    /// Construct a scalar subtraction operation.
1202    fn sub_scalar(
1203        self,
1204        rhs: Self::DType,
1205    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
1206}
1207
1208impl<T, A, P> NDArrayMathScalar for Array<T, A, P>
1209where
1210    T: CType,
1211    A: Access<T>,
1212    P: ElementwiseScalar<A, T>,
1213{
1214    type Output = AccessOp<P::Op, P>;
1215
1216    fn add_scalar(
1217        self,
1218        rhs: Self::DType,
1219    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1220        self.apply(|platform, left| platform.add_scalar(left, rhs))
1221    }
1222
1223    fn div_scalar(
1224        self,
1225        rhs: Self::DType,
1226    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1227        if rhs != T::ZERO {
1228            self.apply(|platform, left| platform.div_scalar(left, rhs))
1229        } else {
1230            Err(Error::Unsupported(format!(
1231                "cannot divide {self:?} by {rhs}"
1232            )))
1233        }
1234    }
1235
1236    fn log_scalar(
1237        self,
1238        base: Self::DType,
1239    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1240        self.apply(|platform, arg| platform.log_scalar(arg, base))
1241    }
1242
1243    fn mul_scalar(
1244        self,
1245        rhs: Self::DType,
1246    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1247        self.apply(|platform, left| platform.mul_scalar(left, rhs))
1248    }
1249
1250    fn pow_scalar(
1251        self,
1252        exp: Self::DType,
1253    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1254        self.apply(|platform, arg| platform.pow_scalar(arg, exp))
1255    }
1256
1257    fn rem_scalar(
1258        self,
1259        rhs: Self::DType,
1260    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1261        self.apply(|platform, left| platform.rem_scalar(left, rhs))
1262    }
1263
1264    fn sub_scalar(
1265        self,
1266        rhs: Self::DType,
1267    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1268        self.apply(|platform, left| platform.sub_scalar(left, rhs))
1269    }
1270}
1271
1272/// Float-specific array methods
1273pub trait NDArrayNumeric: NDArray + Sized
1274where
1275    Self::DType: Float,
1276{
1277    type Output: Access<u8>;
1278
1279    /// Test which elements of this array are infinite.
1280    fn is_inf(self) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
1281
1282    /// Test which elements of this array are not-a-number.
1283    fn is_nan(self) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
1284}
1285
1286impl<T, A, P> NDArrayNumeric for Array<T, A, P>
1287where
1288    T: Float,
1289    A: Access<T>,
1290    P: ElementwiseNumeric<A, T>,
1291{
1292    type Output = AccessOp<P::Op, P>;
1293
1294    fn is_inf(self) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1295        self.apply(|platform, access| platform.is_inf(access))
1296    }
1297
1298    fn is_nan(self) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1299        self.apply(|platform, access| platform.is_nan(access))
1300    }
1301}
1302
1303/// Boolean array reduce operations
1304pub trait NDArrayReduceBoolean: NDArrayRead {
1305    /// Return `true` if this array contains only non-zero elements.
1306    fn all(self) -> Result<bool, Error>;
1307
1308    /// Return `true` if this array contains any non-zero elements.
1309    fn any(self) -> Result<bool, Error>;
1310}
1311
1312impl<T, A, P> NDArrayReduceBoolean for Array<T, A, P>
1313where
1314    T: CType,
1315    A: Access<T>,
1316    P: ReduceAll<A, T>,
1317{
1318    fn all(self) -> Result<bool, Error> {
1319        self.platform.all(self.access)
1320    }
1321
1322    fn any(self) -> Result<bool, Error> {
1323        self.platform.any(self.access)
1324    }
1325}
1326
1327/// Array reduce operations
1328pub trait NDArrayReduceAll: NDArrayRead {
1329    /// Return the maximum of all elements in this array.
1330    fn max_all(self) -> Result<Self::DType, Error>;
1331
1332    /// Return the minimum of all elements in this array.
1333    fn min_all(self) -> Result<Self::DType, Error>;
1334
1335    /// Return the product of all elements in this array.
1336    fn product_all(self) -> Result<Self::DType, Error>;
1337
1338    /// Return the sum of all elements in this array.
1339    fn sum_all(self) -> Result<Self::DType, Error>;
1340}
1341
1342impl<'a, T, A, P> NDArrayReduceAll for Array<T, A, P>
1343where
1344    T: CType,
1345    A: Access<T>,
1346    P: ReduceAll<A, T>,
1347{
1348    fn max_all(self) -> Result<Self::DType, Error> {
1349        self.platform.max(self.access)
1350    }
1351
1352    fn min_all(self) -> Result<Self::DType, Error> {
1353        self.platform.min(self.access)
1354    }
1355
1356    fn product_all(self) -> Result<Self::DType, Error> {
1357        self.platform.product(self.access)
1358    }
1359
1360    fn sum_all(self) -> Result<T, Error> {
1361        self.platform.sum(self.access)
1362    }
1363}
1364
1365impl<T, A, P> fmt::Debug for Array<T, A, P> {
1366    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1367        write!(
1368            f,
1369            "a {} array of shape {:?}",
1370            std::any::type_name::<T>(),
1371            self.shape
1372        )
1373    }
1374}
1375
1376/// Array trigonometry methods
1377pub trait NDArrayTrig: NDArray + Sized {
1378    type Output: Access<<Self::DType as CType>::Float>;
1379
1380    /// Construct a new sine operation.
1381    fn sin(
1382        self,
1383    ) -> Result<Array<<Self::DType as CType>::Float, Self::Output, Self::Platform>, Error>;
1384
1385    /// Construct a new arcsine operation.
1386    fn asin(
1387        self,
1388    ) -> Result<Array<<Self::DType as CType>::Float, Self::Output, Self::Platform>, Error>;
1389
1390    /// Construct a new hyperbolic sine operation.
1391    fn sinh(
1392        self,
1393    ) -> Result<Array<<Self::DType as CType>::Float, Self::Output, Self::Platform>, Error>;
1394
1395    /// Construct a new cos operation.
1396    fn cos(
1397        self,
1398    ) -> Result<Array<<Self::DType as CType>::Float, Self::Output, Self::Platform>, Error>;
1399
1400    /// Construct a new arccosine operation.
1401    fn acos(
1402        self,
1403    ) -> Result<Array<<Self::DType as CType>::Float, Self::Output, Self::Platform>, Error>;
1404
1405    /// Construct a new hyperbolic cosine operation.
1406    fn cosh(
1407        self,
1408    ) -> Result<Array<<Self::DType as CType>::Float, Self::Output, Self::Platform>, Error>;
1409
1410    /// Construct a new tangent operation.
1411    fn tan(
1412        self,
1413    ) -> Result<Array<<Self::DType as CType>::Float, Self::Output, Self::Platform>, Error>;
1414
1415    /// Construct a new arctangent operation.
1416    fn atan(
1417        self,
1418    ) -> Result<Array<<Self::DType as CType>::Float, Self::Output, Self::Platform>, Error>;
1419
1420    /// Construct a new hyperbolic tangent operation.
1421    fn tanh(
1422        self,
1423    ) -> Result<Array<<Self::DType as CType>::Float, Self::Output, Self::Platform>, Error>;
1424}
1425
1426impl<T, A, P> NDArrayTrig for Array<T, A, P>
1427where
1428    T: CType,
1429    A: Access<T>,
1430    P: ElementwiseTrig<A, T>,
1431{
1432    type Output = AccessOp<P::Op, P>;
1433
1434    fn sin(self) -> Result<Array<T::Float, Self::Output, Self::Platform>, Error> {
1435        self.apply(|platform, access| platform.sin(access))
1436    }
1437
1438    fn asin(self) -> Result<Array<T::Float, Self::Output, Self::Platform>, Error> {
1439        self.apply(|platform, access| platform.asin(access))
1440    }
1441
1442    fn sinh(self) -> Result<Array<T::Float, Self::Output, Self::Platform>, Error> {
1443        self.apply(|platform, access| platform.sinh(access))
1444    }
1445
1446    fn cos(self) -> Result<Array<T::Float, Self::Output, Self::Platform>, Error> {
1447        self.apply(|platform, access| platform.cos(access))
1448    }
1449
1450    fn acos(self) -> Result<Array<T::Float, Self::Output, Self::Platform>, Error> {
1451        self.apply(|platform, access| platform.acos(access))
1452    }
1453
1454    fn cosh(self) -> Result<Array<T::Float, Self::Output, Self::Platform>, Error> {
1455        self.apply(|platform, access| platform.cosh(access))
1456    }
1457
1458    fn tan(self) -> Result<Array<T::Float, Self::Output, Self::Platform>, Error> {
1459        self.apply(|platform, access| platform.tan(access))
1460    }
1461
1462    fn atan(self) -> Result<Array<T::Float, Self::Output, Self::Platform>, Error> {
1463        self.apply(|platform, access| platform.atan(access))
1464    }
1465
1466    fn tanh(self) -> Result<Array<T::Float, Self::Output, Self::Platform>, Error> {
1467        self.apply(|platform, access| platform.tanh(access))
1468    }
1469}
1470
1471/// Conditional selection (boolean logic) methods
1472pub trait NDArrayWhere<T, L, R>: NDArray<DType = u8> + fmt::Debug
1473where
1474    T: CType,
1475{
1476    type Output: Access<T>;
1477
1478    /// Construct a boolean selection operation.
1479    /// The resulting array will return values from `then` where `self` is `true`
1480    /// and from `or_else` where `self` is `false`.
1481    fn cond(self, then: L, or_else: R) -> Result<Array<T, Self::Output, Self::Platform>, Error>;
1482}
1483
1484impl<T, A, L, R, P> NDArrayWhere<T, Array<T, L, P>, Array<T, R, P>> for Array<u8, A, P>
1485where
1486    T: CType,
1487    A: Access<u8>,
1488    L: Access<T>,
1489    R: Access<T>,
1490    P: GatherCond<A, L, R, T>,
1491{
1492    type Output = AccessOp<P::Op, P>;
1493
1494    fn cond(
1495        self,
1496        then: Array<T, L, P>,
1497        or_else: Array<T, R, P>,
1498    ) -> Result<Array<T, Self::Output, Self::Platform>, Error> {
1499        same_shape("cond", self.shape(), then.shape())?;
1500        same_shape("cond", self.shape(), or_else.shape())?;
1501
1502        let access = self
1503            .platform
1504            .cond(self.access, then.access, or_else.access)?;
1505
1506        Ok(Array {
1507            shape: self.shape,
1508            access,
1509            platform: self.platform,
1510            dtype: PhantomData,
1511        })
1512    }
1513}
1514
1515/// Matrix dual operations
1516pub trait MatrixDual<O>: NDArray + fmt::Debug
1517where
1518    O: NDArray<DType = Self::DType> + fmt::Debug,
1519{
1520    type Output: Access<Self::DType>;
1521
1522    /// Construct an operation to multiply this matrix or batch of matrices with the `other`.
1523    fn matmul(self, other: O) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
1524}
1525
1526impl<T, L, R, P> MatrixDual<Array<T, R, P>> for Array<T, L, P>
1527where
1528    T: CType,
1529    L: Access<T>,
1530    R: Access<T>,
1531    P: LinAlgDual<L, R, T>,
1532{
1533    type Output = AccessOp<P::Op, P>;
1534
1535    fn matmul(
1536        self,
1537        other: Array<T, R, P>,
1538    ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1539        let dims = matmul_dims(&self.shape, &other.shape).ok_or_else(|| {
1540            Error::Bounds(format!(
1541                "invalid dimensions for matrix multiply: {:?} and {:?}",
1542                self.shape, other.shape
1543            ))
1544        })?;
1545
1546        let mut shape = Shape::with_capacity(self.ndim());
1547        shape.extend(self.shape.iter().rev().skip(2).rev().copied());
1548        shape.push(dims[1]);
1549        shape.push(dims[3]);
1550
1551        let platform = P::select(dims.iter().product());
1552
1553        let access = platform.matmul(self.access, other.access, dims)?;
1554
1555        Ok(Array {
1556            shape,
1557            access,
1558            platform,
1559            dtype: self.dtype,
1560        })
1561    }
1562}
1563
1564/// Matrix unary operations
1565pub trait MatrixUnary: NDArray + fmt::Debug {
1566    type Diag: Access<Self::DType>;
1567
1568    /// Construct an operation to read the diagonal(s) of this matrix or batch of matrices.
1569    /// This will return an error if the last two dimensions of the batch are unequal.
1570    fn diag(self) -> Result<Array<Self::DType, Self::Diag, Self::Platform>, Error>;
1571}
1572
1573impl<T, A, P> MatrixUnary for Array<T, A, P>
1574where
1575    T: CType,
1576    A: Access<T>,
1577    P: LinAlgUnary<A, T>,
1578{
1579    type Diag = AccessOp<P::Op, P>;
1580
1581    fn diag(self) -> Result<Array<T, AccessOp<P::Op, P>, P>, Error> {
1582        if self.ndim() >= 2 && self.shape.last() == self.shape.iter().nth_back(1) {
1583            let batch_size = self.shape.iter().rev().skip(2).product();
1584            let dim = self.shape.last().copied().expect("dim");
1585
1586            let shape = self.shape.iter().rev().skip(1).rev().copied().collect();
1587            let platform = P::select(batch_size * dim * dim);
1588            let access = platform.diag(self.access, batch_size, dim)?;
1589
1590            Ok(Array {
1591                shape,
1592                access,
1593                platform,
1594                dtype: PhantomData,
1595            })
1596        } else {
1597            Err(Error::Bounds(format!(
1598                "invalid shape for diagonal: {:?}",
1599                self.shape
1600            )))
1601        }
1602    }
1603}
1604
1605#[inline]
1606fn can_broadcast(left: &[usize], right: &[usize]) -> bool {
1607    if left.len() < right.len() {
1608        return can_broadcast(right, left);
1609    }
1610
1611    for (l, r) in left.iter().copied().rev().zip(right.iter().copied().rev()) {
1612        if l == r || l == 1 || r == 1 {
1613            // pass
1614        } else {
1615            return false;
1616        }
1617    }
1618
1619    true
1620}
1621
1622#[inline]
1623fn matmul_dims(left: &[usize], right: &[usize]) -> Option<[usize; 4]> {
1624    let mut left = left.into_iter().copied().rev();
1625    let mut right = right.into_iter().copied().rev();
1626
1627    let b = left.next()?;
1628    let a = left.next()?;
1629
1630    let c = right.next()?;
1631    if right.next()? != b {
1632        return None;
1633    }
1634
1635    let mut batch_size = 1;
1636    loop {
1637        match (left.next(), right.next()) {
1638            (Some(l), Some(r)) if l == r => {
1639                batch_size *= l;
1640            }
1641            (None, None) => break,
1642            _ => return None,
1643        }
1644    }
1645
1646    Some([batch_size, a, b, c])
1647}
1648
1649#[inline]
1650fn permute_for_reduce<T, A, P>(
1651    platform: P,
1652    access: A,
1653    shape: Shape,
1654    axes: Axes,
1655) -> Result<Accessor<T>, Error>
1656where
1657    T: CType,
1658    A: Access<T>,
1659    P: Transform<A, T>,
1660    Accessor<T>: From<A> + From<AccessOp<P::Transpose, P>>,
1661{
1662    let mut permutation = Axes::with_capacity(shape.len());
1663    permutation.extend((0..shape.len()).into_iter().filter(|x| !axes.contains(x)));
1664    permutation.extend(axes);
1665
1666    if permutation.iter().copied().enumerate().all(|(i, x)| i == x) {
1667        Ok(Accessor::from(access))
1668    } else {
1669        platform
1670            .transpose(access, shape, permutation)
1671            .map(Accessor::from)
1672    }
1673}
1674
1675#[inline]
1676fn reduce_axes(shape: &[usize], axes: &[usize], keepdims: bool) -> Result<Shape, Error> {
1677    let mut shape = Shape::from_slice(shape);
1678
1679    for x in axes.iter().copied().rev() {
1680        if x >= shape.len() {
1681            return Err(Error::Bounds(format!(
1682                "axis {x} is out of bounds for {shape:?}"
1683            )));
1684        } else if keepdims {
1685            shape[x] = 1;
1686        } else {
1687            shape.remove(x);
1688        }
1689    }
1690
1691    if shape.is_empty() {
1692        Ok(shape![1])
1693    } else {
1694        Ok(shape)
1695    }
1696}
1697
1698#[inline]
1699fn same_shape(op_name: &'static str, left: &[usize], right: &[usize]) -> Result<(), Error> {
1700    if left == right {
1701        Ok(())
1702    } else if can_broadcast(left, right) {
1703        Err(Error::Bounds(format!(
1704            "cannot {op_name} arrays with shapes {left:?} and {right:?} (consider broadcasting)"
1705        )))
1706    } else {
1707        Err(Error::Bounds(format!(
1708            "cannot {op_name} arrays with shapes {left:?} and {right:?}"
1709        )))
1710    }
1711}
1712
1713#[inline]
1714fn valid_coord(coord: &[usize], shape: &[usize]) -> Result<(), Error> {
1715    if coord.len() == shape.len() {
1716        if coord.iter().zip(shape).all(|(i, dim)| i < dim) {
1717            return Ok(());
1718        }
1719    }
1720
1721    Err(Error::Bounds(format!(
1722        "invalid coordinate {coord:?} for shape {shape:?}"
1723    )))
1724}