mdarray/
ops.rs

1#[cfg(feature = "nightly")]
2use alloc::alloc::Allocator;
3
4use core::ops::{
5    Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, DivAssign,
6    Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign,
7};
8
9#[cfg(not(feature = "nightly"))]
10use crate::allocator::Allocator;
11use crate::array::Array;
12use crate::expr::{Apply, Buffer, Expression, IntoExpression};
13use crate::expr::{Fill, FillWith, FromElem, FromFn, IntoExpr, Map};
14use crate::layout::Layout;
15use crate::shape::{ConstShape, Shape};
16use crate::slice::Slice;
17use crate::tensor::Tensor;
18use crate::view::{View, ViewMut};
19
/// A unit-spaced range paired with a step size.
///
/// Values of this type are usually created with [`step()`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct StepRange<R, S> {
    /// The underlying unit-spaced range.
    pub range: R,

    /// The step size applied to `range`.
    pub step: S,
}
29
30/// Creates a range with the given step size from a unit spaced range.
31///
32/// If the step size is negative, the result is obtained by reversing the input range
33/// and stepping by the absolute value of the step size.
34///
35/// # Examples
36///
37/// ```
38/// use mdarray::{step, view};
39///
40/// let v = view![0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
41///
42/// assert_eq!(v.view(step(0..10, 2)).to_vec(), [0, 2, 4, 6, 8]);
43/// assert_eq!(v.view(step(0..10, -2)).to_vec(), [9, 7, 5, 3, 1]);
44/// ```
45#[inline]
46pub fn step<R, S>(range: R, step: S) -> StepRange<R, S> {
47    StepRange { range, step }
48}
49
50impl<T: Eq, S: ConstShape> Eq for Array<T, S> {}
51impl<T: Eq, S: Shape, L: Layout> Eq for Slice<T, S, L> {}
52impl<T: Eq, S: Shape, A: Allocator> Eq for Tensor<T, S, A> {}
53impl<T: Eq, S: Shape, L: Layout> Eq for View<'_, T, S, L> {}
54impl<T: Eq, S: Shape, L: Layout> Eq for ViewMut<'_, T, S, L> {}
55
56impl<T, U, S: ConstShape, R: Shape, L: Layout, I: ?Sized> PartialEq<I> for Array<T, S>
57where
58    for<'a> &'a I: IntoExpression<IntoExpr = View<'a, U, R, L>>,
59    T: PartialEq<U>,
60{
61    #[inline]
62    fn eq(&self, other: &I) -> bool {
63        (**self).eq(other)
64    }
65}
66
/// Element-wise equality between a slice and any `I` whose reference converts into a view.
///
/// Returns `true` only when both operands have identical dimensions (same rank and the
/// same extent along every axis) and each pair of corresponding elements compares equal
/// via `T: PartialEq<U>`. When the dimensions differ, `false` is returned without
/// comparing any elements.
impl<T, U, S: Shape, R: Shape, L: Layout, K: Layout, I: ?Sized> PartialEq<I> for Slice<T, S, L>
where
    for<'a> &'a I: IntoExpression<IntoExpr = View<'a, U, R, K>>,
    T: PartialEq<U>,
{
    #[inline]
    fn eq(&self, other: &I) -> bool {
        // Turn the right-hand side into a `View` so both operands expose slice methods.
        let other = other.into_expr();

        // Compare the dimension lists first; lists of different length (rank) or with any
        // differing extent are unequal.
        if self.shape().with_dims(|dims| other.shape().with_dims(|other| dims == other)) {
            // Avoid very long compile times for release build with MIR inlining,
            // by avoiding recursion until types are known.
            //
            // This is a workaround until const if is available, see #3582 and #122301.

            // Both layouts dense: compare the element storage as two plain slices.
            // NOTE(review): assumes `remap::<S, _>()` yields a dense slice over the same
            // elements so that `[..]` is the complete element sequence — confirm in the
            // `Slice` implementation.
            #[inline]
            fn compare_dense<T, U, S: Shape, R: Shape, L: Layout, K: Layout>(
                this: &Slice<T, S, L>,
                other: &Slice<U, R, K>,
            ) -> bool
            where
                T: PartialEq<U>,
            {
                this.remap::<S, _>()[..].eq(&other.remap::<R, _>()[..])
            }

            // At least one layout strided. Rank 0 or 1: compare the element iterators
            // directly. Higher ranks: compare the outermost sub-expressions pairwise, which
            // presumably re-enters `PartialEq` on lower-rank views.
            #[inline]
            fn compare_strided<T, U, S: Shape, R: Shape, L: Layout, K: Layout>(
                this: &Slice<T, S, L>,
                other: &Slice<U, R, K>,
            ) -> bool
            where
                T: PartialEq<U>,
            {
                if this.rank() < 2 {
                    this.iter().eq(other)
                } else {
                    this.outer_expr().into_iter().eq(other.outer_expr())
                }
            }

            // The choice is made inside a `const` block from `L::IS_DENSE` and
            // `K::IS_DENSE`, so it depends only on the layout types, never on runtime data.
            let f =
                const { if L::IS_DENSE && K::IS_DENSE { compare_dense } else { compare_strided } };

            f(self, &other)
        } else {
            false
        }
    }
}
117
118impl<T, U, S: Shape, R: Shape, L: Layout, A: Allocator, I: ?Sized> PartialEq<I> for Tensor<T, S, A>
119where
120    for<'a> &'a I: IntoExpression<IntoExpr = View<'a, U, R, L>>,
121    T: PartialEq<U>,
122{
123    #[inline]
124    fn eq(&self, other: &I) -> bool {
125        (**self).eq(other)
126    }
127}
128
129impl<T, U, S: Shape, R: Shape, L: Layout, K: Layout, I: ?Sized> PartialEq<I> for View<'_, T, S, L>
130where
131    for<'a> &'a I: IntoExpression<IntoExpr = View<'a, U, R, K>>,
132    T: PartialEq<U>,
133{
134    #[inline]
135    fn eq(&self, other: &I) -> bool {
136        (**self).eq(other)
137    }
138}
139
140impl<T, U, S: Shape, R: Shape, L: Layout, K: Layout, I: ?Sized> PartialEq<I>
141    for ViewMut<'_, T, S, L>
142where
143    for<'a> &'a I: IntoExpression<IntoExpr = View<'a, U, R, K>>,
144    T: PartialEq<U>,
145{
146    #[inline]
147    fn eq(&self, other: &I) -> bool {
148        (**self).eq(other)
149    }
150}
151
152macro_rules! impl_binary_op {
153    ($trt:tt, $fn:tt) => {
154        impl<'a, T, U, S: ConstShape, I: Apply<U>> $trt<I> for &'a Array<T, S>
155        where
156            &'a T: $trt<I::Item, Output = U>,
157        {
158            #[cfg(not(feature = "nightly"))]
159            type Output = I::ZippedWith<Self, fn((I::Item, &'a T)) -> U>;
160
161            #[cfg(feature = "nightly")]
162            type Output = I::ZippedWith<Self, impl FnMut((I::Item, &'a T)) -> U>;
163
164            #[inline]
165            fn $fn(self, rhs: I) -> Self::Output {
166                rhs.zip_with(self, |(x, y)| y.$fn(x))
167            }
168        }
169
170        impl<'a, T, U, S: Shape, L: Layout, I: Apply<U>> $trt<I> for &'a Slice<T, S, L>
171        where
172            &'a T: $trt<I::Item, Output = U>,
173        {
174            #[cfg(not(feature = "nightly"))]
175            type Output = I::ZippedWith<Self, fn((I::Item, &'a T)) -> U>;
176
177            #[cfg(feature = "nightly")]
178            type Output = I::ZippedWith<Self, impl FnMut((I::Item, &'a T)) -> U>;
179
180            #[inline]
181            fn $fn(self, rhs: I) -> Self::Output {
182                rhs.zip_with(self, |(x, y)| y.$fn(x))
183            }
184        }
185
186        impl<'a, T, U, S: Shape, A: Allocator, I: Apply<U>> $trt<I> for &'a Tensor<T, S, A>
187        where
188            &'a T: $trt<I::Item, Output = U>,
189        {
190            #[cfg(not(feature = "nightly"))]
191            type Output = I::ZippedWith<Self, fn((I::Item, &'a T)) -> U>;
192
193            #[cfg(feature = "nightly")]
194            type Output = I::ZippedWith<Self, impl FnMut((I::Item, &'a T)) -> U>;
195
196            #[inline]
197            fn $fn(self, rhs: I) -> Self::Output {
198                rhs.zip_with(self, |(x, y)| y.$fn(x))
199            }
200        }
201
202        impl<'a, T, U, S: Shape, L: Layout, I: Apply<U>> $trt<I> for &'a View<'_, T, S, L>
203        where
204            &'a T: $trt<I::Item, Output = U>,
205        {
206            #[cfg(not(feature = "nightly"))]
207            type Output = I::ZippedWith<Self, fn((I::Item, &'a T)) -> U>;
208
209            #[cfg(feature = "nightly")]
210            type Output = I::ZippedWith<Self, impl FnMut((I::Item, &'a T)) -> U>;
211
212            #[inline]
213            fn $fn(self, rhs: I) -> Self::Output {
214                rhs.zip_with(self, |(x, y)| y.$fn(x))
215            }
216        }
217
218        impl<'a, T, U, S: Shape, L: Layout, I: Apply<U>> $trt<I> for &'a ViewMut<'_, T, S, L>
219        where
220            &'a T: $trt<I::Item, Output = U>,
221        {
222            #[cfg(not(feature = "nightly"))]
223            type Output = I::ZippedWith<Self, fn((I::Item, &'a T)) -> U>;
224
225            #[cfg(feature = "nightly")]
226            type Output = I::ZippedWith<Self, impl FnMut((I::Item, &'a T)) -> U>;
227
228            #[inline]
229            fn $fn(self, rhs: I) -> Self::Output {
230                rhs.zip_with(self, |(x, y)| y.$fn(x))
231            }
232        }
233
234        impl<T, U, S: ConstShape, I: IntoExpression> $trt<I> for Array<T, S>
235        where
236            T: $trt<I::Item, Output = U>,
237        {
238            type Output = Array<U, S>;
239
240            #[inline]
241            fn $fn(self, rhs: I) -> Self::Output {
242                self.zip_with(rhs, |(x, y)| x.$fn(y))
243            }
244        }
245
246        impl<T: Clone, U, I: Apply<U>> $trt<I> for Fill<T>
247        where
248            T: $trt<I::Item, Output = U>,
249        {
250            #[cfg(not(feature = "nightly"))]
251            type Output = I::ZippedWith<Self, fn((I::Item, T)) -> U>;
252
253            #[cfg(feature = "nightly")]
254            type Output = I::ZippedWith<Self, impl FnMut((I::Item, T)) -> U>;
255
256            #[inline]
257            fn $fn(self, rhs: I) -> Self::Output {
258                rhs.zip_with(self, |(x, y)| y.$fn(x))
259            }
260        }
261
262        impl<T: Clone, U, F: FnMut() -> T, I: Apply<U>> $trt<I> for FillWith<F>
263        where
264            T: $trt<I::Item, Output = U>,
265        {
266            #[cfg(not(feature = "nightly"))]
267            type Output = I::ZippedWith<Self, fn((I::Item, T)) -> U>;
268
269            #[cfg(feature = "nightly")]
270            type Output = I::ZippedWith<Self, impl FnMut((I::Item, T)) -> U>;
271
272            #[inline]
273            fn $fn(self, rhs: I) -> Self::Output {
274                rhs.zip_with(self, |(x, y)| y.$fn(x))
275            }
276        }
277
278        impl<S: Shape, T: Clone, U, I: Apply<U>> $trt<I> for FromElem<T, S>
279        where
280            T: $trt<I::Item, Output = U>,
281        {
282            #[cfg(not(feature = "nightly"))]
283            type Output = I::ZippedWith<Self, fn((I::Item, T)) -> U>;
284
285            #[cfg(feature = "nightly")]
286            type Output = I::ZippedWith<Self, impl FnMut((I::Item, T)) -> U>;
287
288            #[inline]
289            fn $fn(self, rhs: I) -> Self::Output {
290                rhs.zip_with(self, |(x, y)| y.$fn(x))
291            }
292        }
293
294        impl<S: Shape, T, U, F: FnMut(&[usize]) -> T, I: Apply<U>> $trt<I> for FromFn<S, F>
295        where
296            T: $trt<I::Item, Output = U>,
297        {
298            #[cfg(not(feature = "nightly"))]
299            type Output = I::ZippedWith<Self, fn((I::Item, T)) -> U>;
300
301            #[cfg(feature = "nightly")]
302            type Output = I::ZippedWith<Self, impl FnMut((I::Item, T)) -> U>;
303
304            #[inline]
305            fn $fn(self, rhs: I) -> Self::Output {
306                rhs.zip_with(self, |(x, y)| y.$fn(x))
307            }
308        }
309
310        impl<T, B: Buffer, I: Apply<T>> $trt<I> for IntoExpr<B>
311        where
312            B::Item: $trt<I::Item, Output = T>,
313        {
314            #[cfg(not(feature = "nightly"))]
315            type Output = I::ZippedWith<Self, fn((I::Item, B::Item)) -> T>;
316
317            #[cfg(feature = "nightly")]
318            type Output = I::ZippedWith<Self, impl FnMut((I::Item, B::Item)) -> T>;
319
320            #[inline]
321            fn $fn(self, rhs: I) -> Self::Output {
322                rhs.zip_with(self, |(x, y)| y.$fn(x))
323            }
324        }
325
326        impl<T, U, E: Expression, F: FnMut(E::Item) -> T, I: Apply<U>> $trt<I> for Map<E, F>
327        where
328            T: $trt<I::Item, Output = U>,
329        {
330            #[cfg(not(feature = "nightly"))]
331            type Output = I::ZippedWith<Self, fn((I::Item, T)) -> U>;
332
333            #[cfg(feature = "nightly")]
334            type Output = I::ZippedWith<Self, impl FnMut((I::Item, T)) -> U>;
335
336            #[inline]
337            fn $fn(self, rhs: I) -> Self::Output {
338                rhs.zip_with(self, |(x, y)| y.$fn(x))
339            }
340        }
341
342        impl<T, S: Shape, A: Allocator, I: IntoExpression> $trt<I> for Tensor<T, S, A>
343        where
344            T: $trt<I::Item, Output = T>,
345        {
346            type Output = Self;
347
348            #[inline]
349            fn $fn(self, rhs: I) -> Self {
350                self.zip_with(rhs, |(x, y)| x.$fn(y))
351            }
352        }
353
354        impl<'a, T, U, S: Shape, L: Layout, I: Apply<U>> $trt<I> for View<'a, T, S, L>
355        where
356            &'a T: $trt<I::Item, Output = U>,
357        {
358            #[cfg(not(feature = "nightly"))]
359            type Output = I::ZippedWith<Self, fn((I::Item, &'a T)) -> U>;
360
361            #[cfg(feature = "nightly")]
362            type Output = I::ZippedWith<Self, impl FnMut((I::Item, &'a T)) -> U>;
363
364            #[inline]
365            fn $fn(self, rhs: I) -> Self::Output {
366                rhs.zip_with(self, |(x, y)| y.$fn(x))
367            }
368        }
369    };
370}
371
372impl_binary_op!(Add, add);
373impl_binary_op!(Sub, sub);
374impl_binary_op!(Mul, mul);
375impl_binary_op!(Div, div);
376impl_binary_op!(Rem, rem);
377impl_binary_op!(BitAnd, bitand);
378impl_binary_op!(BitOr, bitor);
379impl_binary_op!(BitXor, bitxor);
380impl_binary_op!(Shl, shl);
381impl_binary_op!(Shr, shr);
382
383macro_rules! impl_op_assign {
384    ($trt:tt, $fn:tt) => {
385        impl<T, S: ConstShape, I: IntoExpression> $trt<I> for Array<T, S>
386        where
387            T: $trt<I::Item>,
388        {
389            #[inline]
390            fn $fn(&mut self, rhs: I) {
391                self.expr_mut().zip(rhs).for_each(|(x, y)| x.$fn(y));
392            }
393        }
394
395        impl<T, S: Shape, L: Layout, I: IntoExpression> $trt<I> for Slice<T, S, L>
396        where
397            T: $trt<I::Item>,
398        {
399            #[inline]
400            fn $fn(&mut self, rhs: I) {
401                self.expr_mut().zip(rhs).for_each(|(x, y)| x.$fn(y));
402            }
403        }
404
405        impl<T, S: Shape, A: Allocator, I: IntoExpression> $trt<I> for Tensor<T, S, A>
406        where
407            T: $trt<I::Item>,
408        {
409            #[inline]
410            fn $fn(&mut self, rhs: I) {
411                self.expr_mut().zip(rhs).for_each(|(x, y)| x.$fn(y));
412            }
413        }
414
415        impl<T, S: Shape, L: Layout, I: IntoExpression> $trt<I> for ViewMut<'_, T, S, L>
416        where
417            T: $trt<I::Item>,
418        {
419            #[inline]
420            fn $fn(&mut self, rhs: I) {
421                self.expr_mut().zip(rhs).for_each(|(x, y)| x.$fn(y));
422            }
423        }
424    };
425}
426
427impl_op_assign!(AddAssign, add_assign);
428impl_op_assign!(SubAssign, sub_assign);
429impl_op_assign!(MulAssign, mul_assign);
430impl_op_assign!(DivAssign, div_assign);
431impl_op_assign!(RemAssign, rem_assign);
432impl_op_assign!(BitAndAssign, bitand_assign);
433impl_op_assign!(BitOrAssign, bitor_assign);
434impl_op_assign!(BitXorAssign, bitxor_assign);
435impl_op_assign!(ShlAssign, shl_assign);
436impl_op_assign!(ShrAssign, shr_assign);
437
438macro_rules! impl_unary_op {
439    ($trt:tt, $fn:tt) => {
440        impl<'a, T, U, S: ConstShape> $trt for &'a Array<T, S>
441        where
442            &'a T: $trt<Output = U>,
443        {
444            #[cfg(not(feature = "nightly"))]
445            type Output = <Self as Apply<U>>::Output<fn(&'a T) -> U>;
446
447            #[cfg(feature = "nightly")]
448            type Output = <Self as Apply<U>>::Output<impl FnMut(&'a T) -> U>;
449
450            #[inline]
451            fn $fn(self) -> Self::Output {
452                self.apply(|x| x.$fn())
453            }
454        }
455
456        impl<'a, T, U, S: Shape, L: Layout> $trt for &'a Slice<T, S, L>
457        where
458            &'a T: $trt<Output = U>,
459        {
460            #[cfg(not(feature = "nightly"))]
461            type Output = <Self as Apply<U>>::Output<fn(&'a T) -> U>;
462
463            #[cfg(feature = "nightly")]
464            type Output = <Self as Apply<U>>::Output<impl FnMut(&'a T) -> U>;
465
466            #[inline]
467            fn $fn(self) -> Self::Output {
468                self.apply(|x| x.$fn())
469            }
470        }
471
472        impl<'a, T, U, S: Shape, A: Allocator> $trt for &'a Tensor<T, S, A>
473        where
474            &'a T: $trt<Output = U>,
475        {
476            #[cfg(not(feature = "nightly"))]
477            type Output = <Self as Apply<U>>::Output<fn(&'a T) -> U>;
478
479            #[cfg(feature = "nightly")]
480            type Output = <Self as Apply<U>>::Output<impl FnMut(&'a T) -> U>;
481
482            #[inline]
483            fn $fn(self) -> Self::Output {
484                self.apply(|x| x.$fn())
485            }
486        }
487
488        impl<'a, T, U, S: Shape, L: Layout> $trt for &'a View<'_, T, S, L>
489        where
490            &'a T: $trt<Output = U>,
491        {
492            #[cfg(not(feature = "nightly"))]
493            type Output = <Self as Apply<U>>::Output<fn(&'a T) -> U>;
494
495            #[cfg(feature = "nightly")]
496            type Output = <Self as Apply<U>>::Output<impl FnMut(&'a T) -> U>;
497
498            #[inline]
499            fn $fn(self) -> Self::Output {
500                self.apply(|x| x.$fn())
501            }
502        }
503
504        impl<'a, T, U, S: Shape, L: Layout> $trt for &'a ViewMut<'_, T, S, L>
505        where
506            &'a T: $trt<Output = U>,
507        {
508            #[cfg(not(feature = "nightly"))]
509            type Output = <Self as Apply<U>>::Output<fn(&'a T) -> U>;
510
511            #[cfg(feature = "nightly")]
512            type Output = <Self as Apply<U>>::Output<impl FnMut(&'a T) -> U>;
513
514            #[inline]
515            fn $fn(self) -> Self::Output {
516                self.apply(|x| x.$fn())
517            }
518        }
519
520        impl<T, U, S: ConstShape> $trt for Array<T, S>
521        where
522            T: $trt<Output = U>,
523        {
524            type Output = Array<U, S>;
525
526            #[inline]
527            fn $fn(self) -> Self::Output {
528                self.apply(|x| x.$fn())
529            }
530        }
531
532        impl<T: Clone, U> $trt for Fill<T>
533        where
534            T: $trt<Output = U>,
535        {
536            #[cfg(not(feature = "nightly"))]
537            type Output = <Self as Apply<U>>::Output<fn(T) -> U>;
538
539            #[cfg(feature = "nightly")]
540            type Output = <Self as Apply<U>>::Output<impl FnMut(T) -> U>;
541
542            #[inline]
543            fn $fn(self) -> Self::Output {
544                self.apply(|x| x.$fn())
545            }
546        }
547
548        impl<T: Clone, U, F: FnMut() -> T> $trt for FillWith<F>
549        where
550            T: $trt<Output = U>,
551        {
552            #[cfg(not(feature = "nightly"))]
553            type Output = <Self as Apply<U>>::Output<fn(T) -> U>;
554
555            #[cfg(feature = "nightly")]
556            type Output = <Self as Apply<U>>::Output<impl FnMut(T) -> U>;
557
558            #[inline]
559            fn $fn(self) -> Self::Output {
560                self.apply(|x| x.$fn())
561            }
562        }
563
564        impl<S: Shape, T: Clone, U> $trt for FromElem<T, S>
565        where
566            T: $trt<Output = U>,
567        {
568            #[cfg(not(feature = "nightly"))]
569            type Output = <Self as Apply<U>>::Output<fn(T) -> U>;
570
571            #[cfg(feature = "nightly")]
572            type Output = <Self as Apply<U>>::Output<impl FnMut(T) -> U>;
573
574            #[inline]
575            fn $fn(self) -> Self::Output {
576                self.apply(|x| x.$fn())
577            }
578        }
579
580        impl<S: Shape, T, U, F: FnMut(&[usize]) -> T> $trt for FromFn<S, F>
581        where
582            T: $trt<Output = U>,
583        {
584            #[cfg(not(feature = "nightly"))]
585            type Output = <Self as Apply<U>>::Output<fn(T) -> U>;
586
587            #[cfg(feature = "nightly")]
588            type Output = <Self as Apply<U>>::Output<impl FnMut(T) -> U>;
589
590            #[inline]
591            fn $fn(self) -> Self::Output {
592                self.apply(|x| x.$fn())
593            }
594        }
595
596        impl<T, B: Buffer> $trt for IntoExpr<B>
597        where
598            B::Item: $trt<Output = T>,
599        {
600            #[cfg(not(feature = "nightly"))]
601            type Output = <Self as Apply<T>>::Output<fn(B::Item) -> T>;
602
603            #[cfg(feature = "nightly")]
604            type Output = <Self as Apply<T>>::Output<impl FnMut(B::Item) -> T>;
605
606            #[inline]
607            fn $fn(self) -> Self::Output {
608                self.apply(|x| x.$fn())
609            }
610        }
611
612        impl<T, U, E: Expression, F: FnMut(E::Item) -> T> $trt for Map<E, F>
613        where
614            T: $trt<Output = U>,
615        {
616            #[cfg(not(feature = "nightly"))]
617            type Output = <Self as Apply<U>>::Output<fn(T) -> U>;
618
619            #[cfg(feature = "nightly")]
620            type Output = <Self as Apply<U>>::Output<impl FnMut(T) -> U>;
621
622            #[inline]
623            fn $fn(self) -> Self::Output {
624                self.apply(|x| x.$fn())
625            }
626        }
627
628        impl<T, S: Shape, A: Allocator> $trt for Tensor<T, S, A>
629        where
630            T: $trt<Output = T>,
631        {
632            type Output = Self;
633
634            #[inline]
635            fn $fn(self) -> Self {
636                self.apply(|x| x.$fn())
637            }
638        }
639
640        impl<'a, T, U, S: Shape, L: Layout> $trt for View<'a, T, S, L>
641        where
642            &'a T: $trt<Output = U>,
643        {
644            #[cfg(not(feature = "nightly"))]
645            type Output = <Self as Apply<U>>::Output<fn(&'a T) -> U>;
646
647            #[cfg(feature = "nightly")]
648            type Output = <Self as Apply<U>>::Output<impl FnMut(&'a T) -> U>;
649
650            #[inline]
651            fn $fn(self) -> Self::Output {
652                self.apply(|x| x.$fn())
653            }
654        }
655    };
656}
657
658impl_unary_op!(Neg, neg);
659impl_unary_op!(Not, not);