1use std::borrow::{Borrow, BorrowMut};
2use std::fmt;
3use std::marker::PhantomData;
4
5use crate::access::*;
6use crate::buffer::BufferInstance;
7use crate::ops::*;
8use crate::platform::PlatformInstance;
9use crate::{
10 range_shape, shape, strides_for, Axes, AxisRange, BufferConverter, CType, Constant, Convert,
11 Error, Float, Platform, Range, Shape,
12};
13
/// An n-dimensional array whose elements have type `T`.
///
/// The elements are supplied by the accessor `A` (in this file e.g. an `AccessBuf` wrapping a
/// buffer, or an `AccessOp` wrapping an operation), and work on the array is dispatched to the
/// platform `P`.
pub struct Array<T, A, P> {
    // the dimensions of this array
    shape: Shape,
    // the source of this array's elements
    access: A,
    // the platform used to execute operations on this array
    platform: P,
    // records the element type, which is not otherwise stored in this struct
    dtype: PhantomData<T>,
}
20
21impl<T, A: Clone, P: Clone> Clone for Array<T, A, P> {
22 fn clone(&self) -> Self {
23 Self {
24 shape: self.shape.clone(),
25 access: self.access.clone(),
26 platform: self.platform.clone(),
27 dtype: self.dtype,
28 }
29 }
30}
31
32impl<T, A, P> Array<T, A, P> {
33 fn apply<O, OT, Op>(self, op: Op) -> Result<Array<OT, AccessOp<O, P>, P>, Error>
34 where
35 P: Copy,
36 Op: Fn(P, A) -> Result<AccessOp<O, P>, Error>,
37 {
38 let access = (op)(self.platform, self.access)?;
39
40 Ok(Array {
41 shape: self.shape,
42 access,
43 platform: self.platform,
44 dtype: PhantomData,
45 })
46 }
47
48 fn reduce_axes<Op>(
49 self,
50 mut axes: Axes,
51 keepdims: bool,
52 op: Op,
53 ) -> Result<Array<T, AccessOp<P::Op, P>, P>, Error>
54 where
55 T: CType,
56 A: Access<T>,
57 P: Transform<A, T> + ReduceAxes<Accessor<T>, T>,
58 Op: Fn(P, Accessor<T>, usize) -> Result<AccessOp<P::Op, P>, Error>,
59 Accessor<T>: From<A> + From<AccessOp<P::Transpose, P>>,
60 {
61 axes.sort();
62 axes.dedup();
63
64 let shape = reduce_axes(&self.shape, &axes, keepdims)?;
65 let size = shape.iter().product::<usize>();
66 let stride = axes.iter().copied().map(|x| self.shape[x]).product();
67 let platform = P::select(size);
68
69 let access = permute_for_reduce(self.platform, self.access, self.shape, axes)?;
70 let access = (op)(self.platform, access, stride)?;
71
72 Ok(Array {
73 access,
74 shape,
75 platform,
76 dtype: PhantomData,
77 })
78 }
79
80 pub fn access(&self) -> &A {
81 &self.access
82 }
83
84 pub fn into_access(self) -> A {
85 self.access
86 }
87}
88
89impl<T, L, P> Array<T, L, P> {
90 fn apply_dual<O, OT, R, Op>(
91 self,
92 other: Array<T, R, P>,
93 op: Op,
94 ) -> Result<Array<OT, AccessOp<O, P>, P>, Error>
95 where
96 P: Copy,
97 Op: Fn(P, L, R) -> Result<AccessOp<O, P>, Error>,
98 {
99 let access = (op)(self.platform, self.access, other.access)?;
100
101 Ok(Array {
102 shape: self.shape,
103 access,
104 platform: self.platform,
105 dtype: PhantomData,
106 })
107 }
108}
109
110impl<T: CType> Array<T, Accessor<T>, Platform> {
112 pub fn from<A, P>(array: Array<T, A, P>) -> Self
113 where
114 Accessor<T>: From<A>,
115 Platform: From<P>,
116 {
117 Self {
118 shape: array.shape,
119 access: array.access.into(),
120 platform: array.platform.into(),
121 dtype: array.dtype,
122 }
123 }
124}
125
126impl<T, B, P> Array<T, AccessBuf<B>, P>
127where
128 T: CType,
129 B: BufferInstance<T>,
130 P: PlatformInstance,
131{
132 fn new_inner(platform: P, buffer: B, shape: Shape) -> Result<Self, Error> {
133 if !shape.is_empty() && shape.iter().product::<usize>() == buffer.len() {
134 let access = buffer.into();
135
136 Ok(Self {
137 shape,
138 access,
139 platform,
140 dtype: PhantomData,
141 })
142 } else {
143 Err(Error::Bounds(format!(
144 "cannot construct an array with shape {shape:?} from a buffer of size {}",
145 buffer.len(),
146 )))
147 }
148 }
149
150 pub fn convert<'a, FB>(buffer: FB, shape: Shape) -> Result<Self, Error>
151 where
152 FB: Into<BufferConverter<'a, T>>,
153 P: Convert<T, Buffer = B>,
154 {
155 let buffer = buffer.into();
156 let platform = P::select(buffer.len());
157 let buffer = platform.convert(buffer)?;
158 Self::new_inner(platform, buffer, shape)
159 }
160
161 pub fn new(buffer: B, shape: Shape) -> Result<Self, Error> {
162 let platform = P::select(buffer.len());
163 Self::new_inner(platform, buffer, shape)
164 }
165}
166
167impl<T, P> Array<T, AccessBuf<P::Buffer>, P>
168where
169 T: CType,
170 P: Constant<T>,
171{
172 pub fn constant(value: T, shape: Shape) -> Result<Self, Error> {
173 if !shape.is_empty() {
174 let size = shape.iter().product();
175 let platform = P::select(size);
176 let buffer = platform.constant(value, size)?;
177 let access = buffer.into();
178
179 Ok(Self {
180 shape,
181 access,
182 platform,
183 dtype: PhantomData,
184 })
185 } else {
186 Err(Error::Bounds(
187 "cannot construct an array with an empty shape".to_string(),
188 ))
189 }
190 }
191}
192
193impl<T, P> Array<T, AccessBuf<P::Buffer>, P>
194where
195 T: CType,
196 P: Convert<T>,
197{
198 pub fn copy<A: Access<T>>(source: &Array<T, A, P>) -> Result<Self, Error> {
199 let buffer = source
200 .buffer()
201 .and_then(|buf| source.platform.convert(buf))?;
202
203 Ok(Self {
204 shape: source.shape.clone(),
205 access: buffer.into(),
206 platform: source.platform,
207 dtype: source.dtype,
208 })
209 }
210}
211
212impl<T: CType, P: PlatformInstance> Array<T, AccessOp<P::Range, P>, P>
214where
215 P: Construct<T>,
216{
217 pub fn range(start: T, stop: T, shape: Shape) -> Result<Self, Error> {
218 let size = shape.iter().product();
219 let platform = P::select(size);
220
221 platform.range(start, stop, size).map(|access| Self {
222 shape,
223 access,
224 platform,
225 dtype: PhantomData,
226 })
227 }
228}
229
230impl<P: PlatformInstance> Array<f32, AccessOp<P::Normal, P>, P>
231where
232 P: Random,
233{
234 pub fn random_normal(size: usize) -> Result<Self, Error> {
235 let platform = P::select(size);
236 let shape = shape![size];
237
238 platform.random_normal(size).map(|access| Self {
239 shape,
240 access,
241 platform,
242 dtype: PhantomData,
243 })
244 }
245}
246
247impl<P: PlatformInstance> Array<f32, AccessOp<P::Uniform, P>, P>
248where
249 P: Random,
250{
251 pub fn random_uniform(size: usize) -> Result<Self, Error> {
252 let platform = P::select(size);
253 let shape = shape![size];
254
255 platform.random_uniform(size).map(|access| Self {
256 shape,
257 access,
258 platform,
259 dtype: PhantomData,
260 })
261 }
262}
263
264impl<T, B, P> Array<T, AccessBuf<B>, P>
266where
267 T: CType,
268 B: BufferInstance<T>,
269 P: PlatformInstance,
270{
271 pub fn as_mut<RB: ?Sized>(&mut self) -> Array<T, AccessBuf<&mut RB>, P>
272 where
273 B: BorrowMut<RB>,
274 {
275 Array {
276 shape: Shape::from_slice(&self.shape),
277 access: self.access.as_mut(),
278 platform: self.platform,
279 dtype: PhantomData,
280 }
281 }
282
283 pub fn as_ref<RB: ?Sized>(&self) -> Array<T, AccessBuf<&RB>, P>
284 where
285 B: Borrow<RB>,
286 {
287 Array {
288 shape: Shape::from_slice(&self.shape),
289 access: self.access.as_ref(),
290 platform: self.platform,
291 dtype: PhantomData,
292 }
293 }
294}
295
296impl<T, O, P> Array<T, AccessOp<O, P>, P>
297where
298 T: CType,
299 O: Enqueue<P, T>,
300 P: PlatformInstance,
301{
302 pub fn as_mut<'a>(&'a mut self) -> Array<T, &'a mut AccessOp<O, P>, P>
303 where
304 O: Write<P, T>,
305 {
306 Array {
307 shape: Shape::from_slice(&self.shape),
308 access: &mut self.access,
309 platform: self.platform,
310 dtype: PhantomData,
311 }
312 }
313
314 pub fn as_ref(&self) -> Array<T, &AccessOp<O, P>, P> {
315 Array {
316 shape: Shape::from_slice(&self.shape),
317 access: &self.access,
318 platform: self.platform,
319 dtype: PhantomData,
320 }
321 }
322}
323
324pub trait NDArray: Send + Sync {
328 type DType: CType;
330
331 type Platform: PlatformInstance;
333
334 fn ndim(&self) -> usize {
336 self.shape().len()
337 }
338
339 fn size(&self) -> usize {
341 self.shape().iter().product()
342 }
343
344 fn shape(&self) -> &[usize];
346}
347
348impl<T, A, P> NDArray for Array<T, A, P>
349where
350 T: CType,
351 A: Access<T>,
352 P: PlatformInstance,
353{
354 type DType = T;
355 type Platform = P;
356
357 fn shape(&self) -> &[usize] {
358 &self.shape
359 }
360}
361
362pub trait NDArrayRead: NDArray + fmt::Debug + Sized {
364 fn buffer(&self) -> Result<BufferConverter<Self::DType>, Error>;
366
367 fn into_read(
369 self,
370 ) -> Result<
371 Array<
372 Self::DType,
373 AccessBuf<<Self::Platform as Convert<Self::DType>>::Buffer>,
374 Self::Platform,
375 >,
376 Error,
377 >
378 where
379 Self::Platform: Convert<Self::DType>;
380
381 fn read_value(&self, coord: &[usize]) -> Result<Self::DType, Error>;
383}
384
385impl<T, A, P> NDArrayRead for Array<T, A, P>
386where
387 T: CType,
388 A: Access<T>,
389 P: PlatformInstance,
390{
391 fn buffer(&self) -> Result<BufferConverter<T>, Error> {
392 self.access.read()
393 }
394
395 fn into_read(self) -> Result<Array<Self::DType, AccessBuf<P::Buffer>, Self::Platform>, Error>
396 where
397 P: Convert<T>,
398 {
399 let buffer = self.buffer().and_then(|buf| self.platform.convert(buf))?;
400 debug_assert_eq!(buffer.len(), self.size());
401
402 Ok(Array {
403 shape: self.shape,
404 access: buffer.into(),
405 platform: self.platform,
406 dtype: self.dtype,
407 })
408 }
409
410 fn read_value(&self, coord: &[usize]) -> Result<T, Error> {
411 valid_coord(coord, self.shape())?;
412
413 let strides = strides_for(self.shape(), self.ndim());
414
415 let offset = coord
416 .iter()
417 .zip(strides)
418 .map(|(i, stride)| i * stride)
419 .sum();
420
421 self.access.read_value(offset)
422 }
423}
424
425pub trait NDArrayWrite: NDArray + fmt::Debug + Sized {
427 fn write<O: NDArrayRead<DType = Self::DType>>(&mut self, other: &O) -> Result<(), Error>;
429
430 fn write_value(&mut self, value: Self::DType) -> Result<(), Error>;
432
433 fn write_value_at(&mut self, coord: &[usize], value: Self::DType) -> Result<(), Error>;
435}
436
437impl<T, A, P> NDArrayWrite for Array<T, A, P>
439where
440 T: CType,
441 A: AccessMut<T>,
442 P: PlatformInstance,
443{
444 fn write<O>(&mut self, other: &O) -> Result<(), Error>
445 where
446 O: NDArrayRead<DType = Self::DType>,
447 {
448 same_shape("write", self.shape(), other.shape())?;
449 other.buffer().and_then(|buf| self.access.write(buf))
450 }
451
452 fn write_value(&mut self, value: Self::DType) -> Result<(), Error> {
453 self.access.write_value(value)
454 }
455
456 fn write_value_at(&mut self, coord: &[usize], value: Self::DType) -> Result<(), Error> {
457 valid_coord(coord, self.shape())?;
458
459 let offset = coord
460 .iter()
461 .zip(strides_for(self.shape(), self.ndim()))
462 .map(|(i, stride)| i * stride)
463 .sum();
464
465 self.access.write_value_at(offset, value)
466 }
467}
468
469pub trait NDArrayCast<OT: CType>: NDArray + Sized {
473 type Output: Access<OT>;
474
475 fn cast(self) -> Result<Array<OT, Self::Output, Self::Platform>, Error>;
477}
478
479impl<IT, OT, A, P> NDArrayCast<OT> for Array<IT, A, P>
480where
481 IT: CType,
482 OT: CType,
483 A: Access<IT>,
484 P: ElementwiseCast<A, IT, OT>,
485{
486 type Output = AccessOp<P::Op, P>;
487
488 fn cast(self) -> Result<Array<OT, AccessOp<P::Op, P>, P>, Error> {
489 Ok(Array {
490 shape: self.shape,
491 access: self.platform.cast(self.access)?,
492 platform: self.platform,
493 dtype: PhantomData,
494 })
495 }
496}
497
498pub trait NDArrayReduce: NDArray + fmt::Debug {
500 type Output: Access<Self::DType>;
501
502 fn max(
504 self,
505 axes: Axes,
506 keepdims: bool,
507 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
508
509 fn min(
511 self,
512 axes: Axes,
513 keepdims: bool,
514 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
515
516 fn product(
518 self,
519 axes: Axes,
520 keepdims: bool,
521 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
522
523 fn sum(
525 self,
526 axes: Axes,
527 keepdims: bool,
528 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
529}
530
531impl<T, A, P> NDArrayReduce for Array<T, A, P>
532where
533 T: CType,
534 A: Access<T>,
535 P: Transform<A, T> + ReduceAxes<Accessor<T>, T>,
536 Accessor<T>: From<A> + From<AccessOp<P::Transpose, P>>,
537{
538 type Output = AccessOp<P::Op, P>;
539
540 fn max(
541 self,
542 axes: Axes,
543 keepdims: bool,
544 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
545 self.reduce_axes(axes, keepdims, |platform, access, stride| {
546 ReduceAxes::max(platform, access, stride)
547 })
548 }
549
550 fn min(
551 self,
552 axes: Axes,
553 keepdims: bool,
554 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
555 self.reduce_axes(axes, keepdims, |platform, access, stride| {
556 ReduceAxes::min(platform, access, stride)
557 })
558 }
559
560 fn product(
561 self,
562 axes: Axes,
563 keepdims: bool,
564 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
565 self.reduce_axes(axes, keepdims, |platform, access, stride| {
566 ReduceAxes::product(platform, access, stride)
567 })
568 }
569
570 fn sum(
571 self,
572 axes: Axes,
573 keepdims: bool,
574 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
575 self.reduce_axes(axes, keepdims, |platform, access, stride| {
576 ReduceAxes::sum(platform, access, stride)
577 })
578 }
579}
580
581pub trait NDArrayTransform: NDArray + Sized + fmt::Debug {
583 type Broadcast: Access<Self::DType>;
585
586 type Slice: Access<Self::DType>;
588
589 type Transpose: Access<Self::DType>;
591
592 fn broadcast(
594 self,
595 shape: Shape,
596 ) -> Result<Array<Self::DType, Self::Broadcast, Self::Platform>, Error>;
597
598 fn reshape(self, shape: Shape) -> Result<Self, Error>;
600
601 fn slice(self, range: Range) -> Result<Array<Self::DType, Self::Slice, Self::Platform>, Error>;
603
604 fn squeeze(self, axes: Axes) -> Result<Self, Error>;
607
608 fn unsqueeze(self, axes: Axes) -> Result<Self, Error>;
610
611 fn transpose(
614 self,
615 permutation: Option<Axes>,
616 ) -> Result<Array<Self::DType, Self::Transpose, Self::Platform>, Error>;
617}
618
619impl<T, A, P> NDArrayTransform for Array<T, A, P>
620where
621 T: CType,
622 A: Access<T>,
623 P: Transform<A, T>,
624{
625 type Broadcast = AccessOp<P::Broadcast, P>;
626 type Slice = AccessOp<P::Slice, P>;
627 type Transpose = AccessOp<P::Transpose, P>;
628
629 fn broadcast(self, shape: Shape) -> Result<Array<T, AccessOp<P::Broadcast, P>, P>, Error> {
630 if !can_broadcast(self.shape(), &shape) {
631 return Err(Error::Bounds(format!(
632 "cannot broadcast {self:?} into {shape:?}"
633 )));
634 }
635
636 let platform = P::select(shape.iter().product());
637 let broadcast = Shape::from_slice(&shape);
638 let access = platform.broadcast(self.access, self.shape, broadcast)?;
639
640 Ok(Array {
641 shape,
642 access,
643 platform,
644 dtype: self.dtype,
645 })
646 }
647
648 fn reshape(mut self, shape: Shape) -> Result<Self, Error> {
649 if shape.iter().product::<usize>() == self.size() {
650 self.shape = shape;
651 Ok(self)
652 } else {
653 Err(Error::Bounds(format!(
654 "cannot reshape an array with shape {:?} into {shape:?}",
655 self.shape
656 )))
657 }
658 }
659
660 fn slice(self, mut range: Range) -> Result<Array<T, AccessOp<P::Slice, P>, P>, Error> {
661 for (dim, range) in self.shape.iter().zip(&range) {
662 match range {
663 AxisRange::At(i) if i < dim => Ok(()),
664 AxisRange::In(start, stop, _step) if start < dim && stop <= dim => Ok(()),
665 AxisRange::Of(indices) if indices.iter().all(|i| i < dim) => Ok(()),
666 range => Err(Error::Bounds(format!(
667 "invalid range {range:?} for dimension {dim}"
668 ))),
669 }?;
670 }
671
672 for dim in self.shape.iter().skip(range.len()).copied() {
673 range.push(AxisRange::In(0, dim, 1));
674 }
675
676 let shape = range_shape(self.shape(), &range);
677 let access = self.platform.slice(self.access, &self.shape, range)?;
678 let platform = P::select(shape.iter().product());
679
680 Ok(Array {
681 shape,
682 access,
683 platform,
684 dtype: self.dtype,
685 })
686 }
687
688 fn squeeze(mut self, mut axes: Axes) -> Result<Self, Error> {
689 if axes.iter().copied().any(|x| x >= self.ndim()) {
690 return Err(Error::Bounds(format!("invalid contraction axes: {axes:?}")));
691 }
692
693 axes.sort();
694
695 for x in axes.into_iter().rev() {
696 self.shape.remove(x);
697 }
698
699 Ok(self)
700 }
701
702 fn unsqueeze(mut self, mut axes: Axes) -> Result<Self, Error> {
703 if axes.iter().copied().any(|x| x > self.ndim()) {
704 return Err(Error::Bounds(format!("invalid expansion axes: {axes:?}")));
705 }
706
707 axes.sort();
708
709 for x in axes.into_iter().rev() {
710 self.shape.insert(x, 1);
711 }
712
713 Ok(self)
714 }
715
716 fn transpose(
717 self,
718 permutation: Option<Axes>,
719 ) -> Result<Array<T, AccessOp<P::Transpose, P>, P>, Error> {
720 let permutation = if let Some(axes) = permutation {
721 if axes.len() == self.ndim()
722 && axes.iter().copied().all(|x| x < self.ndim())
723 && !(1..axes.len())
724 .into_iter()
725 .any(|i| axes[i..].contains(&axes[i - 1]))
726 {
727 Ok(axes)
728 } else {
729 Err(Error::Bounds(format!(
730 "invalid permutation for shape {:?}: {:?}",
731 self.shape, axes
732 )))
733 }
734 } else {
735 Ok((0..self.ndim()).into_iter().rev().collect())
736 }?;
737
738 let shape = permutation.iter().copied().map(|x| self.shape[x]).collect();
739 let platform = self.platform;
740 let access = platform.transpose(self.access, self.shape, permutation)?;
741
742 Ok(Array {
743 shape,
744 access,
745 platform,
746 dtype: self.dtype,
747 })
748 }
749}
750
751pub trait NDArrayUnary: NDArray + Sized {
753 type Output: Access<Self::DType>;
755
756 fn abs(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
758
759 fn exp(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
761
762 fn ln(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
764
765 fn round(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
767}
768
769impl<T, A, P> NDArrayUnary for Array<T, A, P>
770where
771 T: CType,
772 A: Access<T>,
773 P: ElementwiseUnary<A, T>,
774{
775 type Output = AccessOp<P::Op, P>;
776
777 fn abs(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
778 self.apply(|platform, access| platform.abs(access))
779 }
780
781 fn exp(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
782 self.apply(|platform, access| platform.exp(access))
783 }
784
785 fn ln(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>
786 where
787 P: ElementwiseUnary<A, T>,
788 {
789 self.apply(|platform, access| platform.ln(access))
790 }
791
792 fn round(self) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
793 self.apply(|platform, access| platform.round(access))
794 }
795}
796
797pub trait NDArrayUnaryBoolean: NDArray + Sized {
799 type Output: Access<u8>;
801
802 fn not(self) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
804}
805
806impl<T, A, P> NDArrayUnaryBoolean for Array<T, A, P>
807where
808 T: CType,
809 A: Access<T>,
810 P: ElementwiseUnaryBoolean<A, T>,
811{
812 type Output = AccessOp<P::Op, P>;
813
814 fn not(self) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
815 self.apply(|platform, access| platform.not(access))
816 }
817}
818
819pub trait NDArrayBoolean<O>: NDArray + Sized
821where
822 O: NDArray<DType = Self::DType>,
823{
824 type Output: Access<u8>;
825
826 fn and(self, other: O) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
828
829 fn or(self, other: O) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
831
832 fn xor(self, other: O) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
834}
835
836impl<T, L, R, P> NDArrayBoolean<Array<T, R, P>> for Array<T, L, P>
837where
838 T: CType,
839 L: Access<T>,
840 R: Access<T>,
841 P: ElementwiseBoolean<L, R, T>,
842{
843 type Output = AccessOp<P::Op, P>;
844
845 fn and(self, other: Array<T, R, P>) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
846 same_shape("and", self.shape(), other.shape())?;
847 self.apply_dual(other, |platform, left, right| platform.and(left, right))
848 }
849
850 fn or(self, other: Array<T, R, P>) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
851 same_shape("or", self.shape(), other.shape())?;
852 self.apply_dual(other, |platform, left, right| platform.or(left, right))
853 }
854
855 fn xor(self, other: Array<T, R, P>) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
856 same_shape("xor", self.shape(), other.shape())?;
857 self.apply_dual(other, |platform, left, right| platform.xor(left, right))
858 }
859}
860
861pub trait NDArrayBooleanScalar: NDArray + Sized {
863 type Output: Access<u8>;
864
865 fn and_scalar(
867 self,
868 other: Self::DType,
869 ) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
870
871 fn or_scalar(
873 self,
874 other: Self::DType,
875 ) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
876
877 fn xor_scalar(
879 self,
880 other: Self::DType,
881 ) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
882}
883
884impl<T, A, P> NDArrayBooleanScalar for Array<T, A, P>
885where
886 T: CType,
887 A: Access<T>,
888 P: ElementwiseBooleanScalar<A, T>,
889{
890 type Output = AccessOp<P::Op, P>;
891
892 fn and_scalar(
893 self,
894 other: Self::DType,
895 ) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
896 self.apply(|platform, access| platform.and_scalar(access, other))
897 }
898
899 fn or_scalar(
900 self,
901 other: Self::DType,
902 ) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
903 self.apply(|platform, access| platform.or_scalar(access, other))
904 }
905
906 fn xor_scalar(
907 self,
908 other: Self::DType,
909 ) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
910 self.apply(|platform, access| platform.xor_scalar(access, other))
911 }
912}
913
914pub trait NDArrayCompare<O: NDArray<DType = Self::DType>>: NDArray + Sized {
916 type Output: Access<u8>;
917
918 fn eq(self, other: O) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
920
921 fn ge(self, other: O) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
923
924 fn gt(self, other: O) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
926
927 fn le(self, other: O) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
929
930 fn lt(self, other: O) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
932
933 fn ne(self, other: O) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
935}
936
937impl<T, L, R, P> NDArrayCompare<Array<T, R, P>> for Array<T, L, P>
938where
939 T: CType,
940 L: Access<T>,
941 R: Access<T>,
942 P: ElementwiseCompare<L, R, T>,
943{
944 type Output = AccessOp<P::Op, P>;
945
946 fn eq(self, other: Array<T, R, P>) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
947 same_shape("compare", self.shape(), other.shape())?;
948 self.apply_dual(other, |platform, left, right| platform.eq(left, right))
949 }
950
951 fn ge(self, other: Array<T, R, P>) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
952 same_shape("compare", self.shape(), other.shape())?;
953 self.apply_dual(other, |platform, left, right| platform.ge(left, right))
954 }
955
956 fn gt(self, other: Array<T, R, P>) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
957 same_shape("compare", self.shape(), other.shape())?;
958 self.apply_dual(other, |platform, left, right| platform.gt(left, right))
959 }
960
961 fn le(self, other: Array<T, R, P>) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
962 same_shape("compare", self.shape(), other.shape())?;
963 self.apply_dual(other, |platform, left, right| platform.le(left, right))
964 }
965
966 fn lt(self, other: Array<T, R, P>) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
967 same_shape("compare", self.shape(), other.shape())?;
968 self.apply_dual(other, |platform, left, right| platform.lt(left, right))
969 }
970
971 fn ne(self, other: Array<T, R, P>) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
972 same_shape("compare", self.shape(), other.shape())?;
973 self.apply_dual(other, |platform, left, right| platform.ne(left, right))
974 }
975}
976
977pub trait NDArrayCompareScalar: NDArray + Sized {
979 type Output: Access<u8>;
980
981 fn eq_scalar(
983 self,
984 other: Self::DType,
985 ) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
986
987 fn gt_scalar(
989 self,
990 other: Self::DType,
991 ) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
992
993 fn ge_scalar(
995 self,
996 other: Self::DType,
997 ) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
998
999 fn lt_scalar(
1001 self,
1002 other: Self::DType,
1003 ) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
1004
1005 fn le_scalar(
1007 self,
1008 other: Self::DType,
1009 ) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
1010
1011 fn ne_scalar(
1013 self,
1014 other: Self::DType,
1015 ) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
1016}
1017
1018impl<T, A, P> NDArrayCompareScalar for Array<T, A, P>
1019where
1020 T: CType,
1021 A: Access<T>,
1022 P: ElementwiseScalarCompare<A, T>,
1023{
1024 type Output = AccessOp<P::Op, P>;
1025
1026 fn eq_scalar(
1027 self,
1028 other: Self::DType,
1029 ) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1030 self.apply(|platform, access| platform.eq_scalar(access, other))
1031 }
1032
1033 fn gt_scalar(
1034 self,
1035 other: Self::DType,
1036 ) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1037 self.apply(|platform, access| platform.gt_scalar(access, other))
1038 }
1039
1040 fn ge_scalar(
1041 self,
1042 other: Self::DType,
1043 ) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1044 self.apply(|platform, access| platform.ge_scalar(access, other))
1045 }
1046
1047 fn lt_scalar(
1048 self,
1049 other: Self::DType,
1050 ) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1051 self.apply(|platform, access| platform.lt_scalar(access, other))
1052 }
1053
1054 fn le_scalar(
1055 self,
1056 other: Self::DType,
1057 ) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1058 self.apply(|platform, access| platform.le_scalar(access, other))
1059 }
1060
1061 fn ne_scalar(
1062 self,
1063 other: Self::DType,
1064 ) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1065 self.apply(|platform, access| platform.ne_scalar(access, other))
1066 }
1067}
1068
1069pub trait NDArrayMath<O: NDArray<DType = Self::DType>>: NDArray + Sized {
1071 type Output: Access<Self::DType>;
1072
1073 fn add(self, rhs: O) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
1075
1076 fn div(self, rhs: O) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
1078
1079 fn log(self, base: O) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
1081
1082 fn mul(self, rhs: O) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
1084
1085 fn pow(self, exp: O) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
1087
1088 fn sub(self, rhs: O) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
1090
1091 fn rem(self, rhs: O) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
1093}
1094
1095impl<T, L, R, P> NDArrayMath<Array<T, R, P>> for Array<T, L, P>
1096where
1097 T: CType,
1098 L: Access<T>,
1099 R: Access<T>,
1100 P: ElementwiseDual<L, R, T>,
1101{
1102 type Output = AccessOp<P::Op, P>;
1103
1104 fn add(
1105 self,
1106 rhs: Array<T, R, P>,
1107 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1108 same_shape("add", self.shape(), rhs.shape())?;
1109 self.apply_dual(rhs, |platform, left, right| platform.add(left, right))
1110 }
1111
1112 fn div(
1113 self,
1114 rhs: Array<T, R, P>,
1115 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1116 same_shape("div", self.shape(), rhs.shape())?;
1117 self.apply_dual(rhs, |platform, left, right| platform.div(left, right))
1118 }
1119
1120 fn log(
1121 self,
1122 base: Array<T, R, P>,
1123 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1124 same_shape("log", self.shape(), base.shape())?;
1125 self.apply_dual(base, |platform, left, right| platform.log(left, right))
1126 }
1127
1128 fn mul(
1129 self,
1130 rhs: Array<T, R, P>,
1131 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1132 same_shape("mul", self.shape(), rhs.shape())?;
1133 self.apply_dual(rhs, |platform, left, right| platform.mul(left, right))
1134 }
1135
1136 fn pow(
1137 self,
1138 exp: Array<T, R, P>,
1139 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1140 same_shape("pow", self.shape(), exp.shape())?;
1141 self.apply_dual(exp, |platform, left, right| platform.pow(left, right))
1142 }
1143
1144 fn sub(
1145 self,
1146 rhs: Array<T, R, P>,
1147 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1148 same_shape("sub", self.shape(), rhs.shape())?;
1149 self.apply_dual(rhs, |platform, left, right| platform.sub(left, right))
1150 }
1151
1152 fn rem(
1153 self,
1154 rhs: Array<T, R, P>,
1155 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1156 same_shape("rem", self.shape(), rhs.shape())?;
1157 self.apply_dual(rhs, |platform, left, right| platform.rem(left, right))
1158 }
1159}
1160
1161pub trait NDArrayMathScalar: NDArray + Sized {
1163 type Output: Access<Self::DType>;
1164
1165 fn add_scalar(
1167 self,
1168 rhs: Self::DType,
1169 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
1170
1171 fn div_scalar(
1173 self,
1174 rhs: Self::DType,
1175 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
1176
1177 fn log_scalar(
1179 self,
1180 base: Self::DType,
1181 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
1182
1183 fn mul_scalar(
1185 self,
1186 rhs: Self::DType,
1187 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
1188
1189 fn pow_scalar(
1191 self,
1192 exp: Self::DType,
1193 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
1194
1195 fn rem_scalar(
1197 self,
1198 rhs: Self::DType,
1199 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
1200
1201 fn sub_scalar(
1203 self,
1204 rhs: Self::DType,
1205 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
1206}
1207
1208impl<T, A, P> NDArrayMathScalar for Array<T, A, P>
1209where
1210 T: CType,
1211 A: Access<T>,
1212 P: ElementwiseScalar<A, T>,
1213{
1214 type Output = AccessOp<P::Op, P>;
1215
1216 fn add_scalar(
1217 self,
1218 rhs: Self::DType,
1219 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1220 self.apply(|platform, left| platform.add_scalar(left, rhs))
1221 }
1222
1223 fn div_scalar(
1224 self,
1225 rhs: Self::DType,
1226 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1227 if rhs != T::ZERO {
1228 self.apply(|platform, left| platform.div_scalar(left, rhs))
1229 } else {
1230 Err(Error::Unsupported(format!(
1231 "cannot divide {self:?} by {rhs}"
1232 )))
1233 }
1234 }
1235
1236 fn log_scalar(
1237 self,
1238 base: Self::DType,
1239 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1240 self.apply(|platform, arg| platform.log_scalar(arg, base))
1241 }
1242
1243 fn mul_scalar(
1244 self,
1245 rhs: Self::DType,
1246 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1247 self.apply(|platform, left| platform.mul_scalar(left, rhs))
1248 }
1249
1250 fn pow_scalar(
1251 self,
1252 exp: Self::DType,
1253 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1254 self.apply(|platform, arg| platform.pow_scalar(arg, exp))
1255 }
1256
1257 fn rem_scalar(
1258 self,
1259 rhs: Self::DType,
1260 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1261 self.apply(|platform, left| platform.rem_scalar(left, rhs))
1262 }
1263
1264 fn sub_scalar(
1265 self,
1266 rhs: Self::DType,
1267 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1268 self.apply(|platform, left| platform.sub_scalar(left, rhs))
1269 }
1270}
1271
1272pub trait NDArrayNumeric: NDArray + Sized
1274where
1275 Self::DType: Float,
1276{
1277 type Output: Access<u8>;
1278
1279 fn is_inf(self) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
1281
1282 fn is_nan(self) -> Result<Array<u8, Self::Output, Self::Platform>, Error>;
1284}
1285
1286impl<T, A, P> NDArrayNumeric for Array<T, A, P>
1287where
1288 T: Float,
1289 A: Access<T>,
1290 P: ElementwiseNumeric<A, T>,
1291{
1292 type Output = AccessOp<P::Op, P>;
1293
1294 fn is_inf(self) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1295 self.apply(|platform, access| platform.is_inf(access))
1296 }
1297
1298 fn is_nan(self) -> Result<Array<u8, Self::Output, Self::Platform>, Error> {
1299 self.apply(|platform, access| platform.is_nan(access))
1300 }
1301}
1302
1303pub trait NDArrayReduceBoolean: NDArrayRead {
1305 fn all(self) -> Result<bool, Error>;
1307
1308 fn any(self) -> Result<bool, Error>;
1310}
1311
1312impl<T, A, P> NDArrayReduceBoolean for Array<T, A, P>
1313where
1314 T: CType,
1315 A: Access<T>,
1316 P: ReduceAll<A, T>,
1317{
1318 fn all(self) -> Result<bool, Error> {
1319 self.platform.all(self.access)
1320 }
1321
1322 fn any(self) -> Result<bool, Error> {
1323 self.platform.any(self.access)
1324 }
1325}
1326
1327pub trait NDArrayReduceAll: NDArrayRead {
1329 fn max_all(self) -> Result<Self::DType, Error>;
1331
1332 fn min_all(self) -> Result<Self::DType, Error>;
1334
1335 fn product_all(self) -> Result<Self::DType, Error>;
1337
1338 fn sum_all(self) -> Result<Self::DType, Error>;
1340}
1341
1342impl<'a, T, A, P> NDArrayReduceAll for Array<T, A, P>
1343where
1344 T: CType,
1345 A: Access<T>,
1346 P: ReduceAll<A, T>,
1347{
1348 fn max_all(self) -> Result<Self::DType, Error> {
1349 self.platform.max(self.access)
1350 }
1351
1352 fn min_all(self) -> Result<Self::DType, Error> {
1353 self.platform.min(self.access)
1354 }
1355
1356 fn product_all(self) -> Result<Self::DType, Error> {
1357 self.platform.product(self.access)
1358 }
1359
1360 fn sum_all(self) -> Result<T, Error> {
1361 self.platform.sum(self.access)
1362 }
1363}
1364
1365impl<T, A, P> fmt::Debug for Array<T, A, P> {
1366 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1367 write!(
1368 f,
1369 "a {} array of shape {:?}",
1370 std::any::type_name::<T>(),
1371 self.shape
1372 )
1373 }
1374}
1375
1376pub trait NDArrayTrig: NDArray + Sized {
1378 type Output: Access<<Self::DType as CType>::Float>;
1379
1380 fn sin(
1382 self,
1383 ) -> Result<Array<<Self::DType as CType>::Float, Self::Output, Self::Platform>, Error>;
1384
1385 fn asin(
1387 self,
1388 ) -> Result<Array<<Self::DType as CType>::Float, Self::Output, Self::Platform>, Error>;
1389
1390 fn sinh(
1392 self,
1393 ) -> Result<Array<<Self::DType as CType>::Float, Self::Output, Self::Platform>, Error>;
1394
1395 fn cos(
1397 self,
1398 ) -> Result<Array<<Self::DType as CType>::Float, Self::Output, Self::Platform>, Error>;
1399
1400 fn acos(
1402 self,
1403 ) -> Result<Array<<Self::DType as CType>::Float, Self::Output, Self::Platform>, Error>;
1404
1405 fn cosh(
1407 self,
1408 ) -> Result<Array<<Self::DType as CType>::Float, Self::Output, Self::Platform>, Error>;
1409
1410 fn tan(
1412 self,
1413 ) -> Result<Array<<Self::DType as CType>::Float, Self::Output, Self::Platform>, Error>;
1414
1415 fn atan(
1417 self,
1418 ) -> Result<Array<<Self::DType as CType>::Float, Self::Output, Self::Platform>, Error>;
1419
1420 fn tanh(
1422 self,
1423 ) -> Result<Array<<Self::DType as CType>::Float, Self::Output, Self::Platform>, Error>;
1424}
1425
1426impl<T, A, P> NDArrayTrig for Array<T, A, P>
1427where
1428 T: CType,
1429 A: Access<T>,
1430 P: ElementwiseTrig<A, T>,
1431{
1432 type Output = AccessOp<P::Op, P>;
1433
1434 fn sin(self) -> Result<Array<T::Float, Self::Output, Self::Platform>, Error> {
1435 self.apply(|platform, access| platform.sin(access))
1436 }
1437
1438 fn asin(self) -> Result<Array<T::Float, Self::Output, Self::Platform>, Error> {
1439 self.apply(|platform, access| platform.asin(access))
1440 }
1441
1442 fn sinh(self) -> Result<Array<T::Float, Self::Output, Self::Platform>, Error> {
1443 self.apply(|platform, access| platform.sinh(access))
1444 }
1445
1446 fn cos(self) -> Result<Array<T::Float, Self::Output, Self::Platform>, Error> {
1447 self.apply(|platform, access| platform.cos(access))
1448 }
1449
1450 fn acos(self) -> Result<Array<T::Float, Self::Output, Self::Platform>, Error> {
1451 self.apply(|platform, access| platform.acos(access))
1452 }
1453
1454 fn cosh(self) -> Result<Array<T::Float, Self::Output, Self::Platform>, Error> {
1455 self.apply(|platform, access| platform.cosh(access))
1456 }
1457
1458 fn tan(self) -> Result<Array<T::Float, Self::Output, Self::Platform>, Error> {
1459 self.apply(|platform, access| platform.tan(access))
1460 }
1461
1462 fn atan(self) -> Result<Array<T::Float, Self::Output, Self::Platform>, Error> {
1463 self.apply(|platform, access| platform.atan(access))
1464 }
1465
1466 fn tanh(self) -> Result<Array<T::Float, Self::Output, Self::Platform>, Error> {
1467 self.apply(|platform, access| platform.tanh(access))
1468 }
1469}
1470
1471pub trait NDArrayWhere<T, L, R>: NDArray<DType = u8> + fmt::Debug
1473where
1474 T: CType,
1475{
1476 type Output: Access<T>;
1477
1478 fn cond(self, then: L, or_else: R) -> Result<Array<T, Self::Output, Self::Platform>, Error>;
1482}
1483
1484impl<T, A, L, R, P> NDArrayWhere<T, Array<T, L, P>, Array<T, R, P>> for Array<u8, A, P>
1485where
1486 T: CType,
1487 A: Access<u8>,
1488 L: Access<T>,
1489 R: Access<T>,
1490 P: GatherCond<A, L, R, T>,
1491{
1492 type Output = AccessOp<P::Op, P>;
1493
1494 fn cond(
1495 self,
1496 then: Array<T, L, P>,
1497 or_else: Array<T, R, P>,
1498 ) -> Result<Array<T, Self::Output, Self::Platform>, Error> {
1499 same_shape("cond", self.shape(), then.shape())?;
1500 same_shape("cond", self.shape(), or_else.shape())?;
1501
1502 let access = self
1503 .platform
1504 .cond(self.access, then.access, or_else.access)?;
1505
1506 Ok(Array {
1507 shape: self.shape,
1508 access,
1509 platform: self.platform,
1510 dtype: PhantomData,
1511 })
1512 }
1513}
1514
1515pub trait MatrixDual<O>: NDArray + fmt::Debug
1517where
1518 O: NDArray<DType = Self::DType> + fmt::Debug,
1519{
1520 type Output: Access<Self::DType>;
1521
1522 fn matmul(self, other: O) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error>;
1524}
1525
1526impl<T, L, R, P> MatrixDual<Array<T, R, P>> for Array<T, L, P>
1527where
1528 T: CType,
1529 L: Access<T>,
1530 R: Access<T>,
1531 P: LinAlgDual<L, R, T>,
1532{
1533 type Output = AccessOp<P::Op, P>;
1534
1535 fn matmul(
1536 self,
1537 other: Array<T, R, P>,
1538 ) -> Result<Array<Self::DType, Self::Output, Self::Platform>, Error> {
1539 let dims = matmul_dims(&self.shape, &other.shape).ok_or_else(|| {
1540 Error::Bounds(format!(
1541 "invalid dimensions for matrix multiply: {:?} and {:?}",
1542 self.shape, other.shape
1543 ))
1544 })?;
1545
1546 let mut shape = Shape::with_capacity(self.ndim());
1547 shape.extend(self.shape.iter().rev().skip(2).rev().copied());
1548 shape.push(dims[1]);
1549 shape.push(dims[3]);
1550
1551 let platform = P::select(dims.iter().product());
1552
1553 let access = platform.matmul(self.access, other.access, dims)?;
1554
1555 Ok(Array {
1556 shape,
1557 access,
1558 platform,
1559 dtype: self.dtype,
1560 })
1561 }
1562}
1563
1564pub trait MatrixUnary: NDArray + fmt::Debug {
1566 type Diag: Access<Self::DType>;
1567
1568 fn diag(self) -> Result<Array<Self::DType, Self::Diag, Self::Platform>, Error>;
1571}
1572
1573impl<T, A, P> MatrixUnary for Array<T, A, P>
1574where
1575 T: CType,
1576 A: Access<T>,
1577 P: LinAlgUnary<A, T>,
1578{
1579 type Diag = AccessOp<P::Op, P>;
1580
1581 fn diag(self) -> Result<Array<T, AccessOp<P::Op, P>, P>, Error> {
1582 if self.ndim() >= 2 && self.shape.last() == self.shape.iter().nth_back(1) {
1583 let batch_size = self.shape.iter().rev().skip(2).product();
1584 let dim = self.shape.last().copied().expect("dim");
1585
1586 let shape = self.shape.iter().rev().skip(1).rev().copied().collect();
1587 let platform = P::select(batch_size * dim * dim);
1588 let access = platform.diag(self.access, batch_size, dim)?;
1589
1590 Ok(Array {
1591 shape,
1592 access,
1593 platform,
1594 dtype: PhantomData,
1595 })
1596 } else {
1597 Err(Error::Bounds(format!(
1598 "invalid shape for diagonal: {:?}",
1599 self.shape
1600 )))
1601 }
1602 }
1603}
1604
/// Returns `true` if arrays with shapes `left` and `right` can be broadcast together.
///
/// Dimensions are compared pairwise starting from the trailing end. Each pair must either
/// be equal or contain a `1`. Extra leading dimensions of the longer shape are unconstrained,
/// and the check is symmetric in its arguments.
#[inline]
fn can_broadcast(left: &[usize], right: &[usize]) -> bool {
    // `zip` stops at the shorter shape, so no length-based argument swap is needed.
    left.iter()
        .rev()
        .zip(right.iter().rev())
        .all(|(&l, &r)| l == r || l == 1 || r == 1)
}
1621
/// Validates the shapes of a (batched) matrix multiplication and returns its dimensions.
///
/// Both shapes need at least two dimensions. The trailing dimensions are `[a, b]` for
/// `left` and `[b, c]` for `right`, and the inner `b` must agree. All leading (batch)
/// dimensions must match exactly in count and size; no broadcasting is performed.
///
/// Returns `Some([batch_size, a, b, c])`, where `batch_size` is the product of the batch
/// dimensions (`1` when there are none). Returns `None` if the shapes are incompatible.
#[inline]
fn matmul_dims(left: &[usize], right: &[usize]) -> Option<[usize; 4]> {
    let (left_batch, a, b) = match left {
        [batch @ .., a, b] => (batch, *a, *b),
        _ => return None,
    };

    let (right_batch, inner, c) = match right {
        [batch @ .., inner, c] => (batch, *inner, *c),
        _ => return None,
    };

    // Slice equality checks both the number of batch dims and each dim's size.
    if inner != b || left_batch != right_batch {
        return None;
    }

    Some([left_batch.iter().product(), a, b, c])
}
1648
1649#[inline]
1650fn permute_for_reduce<T, A, P>(
1651 platform: P,
1652 access: A,
1653 shape: Shape,
1654 axes: Axes,
1655) -> Result<Accessor<T>, Error>
1656where
1657 T: CType,
1658 A: Access<T>,
1659 P: Transform<A, T>,
1660 Accessor<T>: From<A> + From<AccessOp<P::Transpose, P>>,
1661{
1662 let mut permutation = Axes::with_capacity(shape.len());
1663 permutation.extend((0..shape.len()).into_iter().filter(|x| !axes.contains(x)));
1664 permutation.extend(axes);
1665
1666 if permutation.iter().copied().enumerate().all(|(i, x)| i == x) {
1667 Ok(Accessor::from(access))
1668 } else {
1669 platform
1670 .transpose(access, shape, permutation)
1671 .map(Accessor::from)
1672 }
1673}
1674
1675#[inline]
1676fn reduce_axes(shape: &[usize], axes: &[usize], keepdims: bool) -> Result<Shape, Error> {
1677 let mut shape = Shape::from_slice(shape);
1678
1679 for x in axes.iter().copied().rev() {
1680 if x >= shape.len() {
1681 return Err(Error::Bounds(format!(
1682 "axis {x} is out of bounds for {shape:?}"
1683 )));
1684 } else if keepdims {
1685 shape[x] = 1;
1686 } else {
1687 shape.remove(x);
1688 }
1689 }
1690
1691 if shape.is_empty() {
1692 Ok(shape![1])
1693 } else {
1694 Ok(shape)
1695 }
1696}
1697
1698#[inline]
1699fn same_shape(op_name: &'static str, left: &[usize], right: &[usize]) -> Result<(), Error> {
1700 if left == right {
1701 Ok(())
1702 } else if can_broadcast(left, right) {
1703 Err(Error::Bounds(format!(
1704 "cannot {op_name} arrays with shapes {left:?} and {right:?} (consider broadcasting)"
1705 )))
1706 } else {
1707 Err(Error::Bounds(format!(
1708 "cannot {op_name} arrays with shapes {left:?} and {right:?}"
1709 )))
1710 }
1711}
1712
1713#[inline]
1714fn valid_coord(coord: &[usize], shape: &[usize]) -> Result<(), Error> {
1715 if coord.len() == shape.len() {
1716 if coord.iter().zip(shape).all(|(i, dim)| i < dim) {
1717 return Ok(());
1718 }
1719 }
1720
1721 Err(Error::Bounds(format!(
1722 "invalid coordinate {coord:?} for shape {shape:?}"
1723 )))
1724}