1#[cfg(feature = "nightly")]
2use alloc::alloc::Allocator;
3
4use core::ops::{
5 Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, DivAssign,
6 Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign,
7};
8
9#[cfg(not(feature = "nightly"))]
10use crate::allocator::Allocator;
11use crate::array::Array;
12use crate::expr::{Apply, Buffer, Expression, IntoExpression};
13use crate::expr::{Fill, FillWith, FromElem, FromFn, IntoExpr, Map};
14use crate::layout::Layout;
15use crate::shape::{ConstShape, Shape};
16use crate::slice::Slice;
17use crate::tensor::Tensor;
18use crate::view::{View, ViewMut};
19
/// A range paired with a step size.
///
/// Both fields are public, so a `StepRange` can be built with a struct
/// literal as well as with the [`step`] shorthand.
// NOTE(review): presumably consumed by the indexing code elsewhere in the
// crate to select every `step`-th element of `range` — confirm.
#[derive(PartialEq, Eq, Hash, Clone, Copy, Debug, Default)]
pub struct StepRange<R, S> {
    /// The underlying range.
    pub range: R,

    /// The step size applied to `range`.
    pub step: S,
}

/// Pairs `range` with `step`, producing a [`StepRange`].
///
/// This is equivalent to `StepRange { range, step }`.
#[inline]
pub fn step<R, S>(range: R, step: S) -> StepRange<R, S> {
    StepRange { step, range }
}
49
50impl<T: Eq, S: ConstShape> Eq for Array<T, S> {}
51impl<T: Eq, S: Shape, L: Layout> Eq for Slice<T, S, L> {}
52impl<T: Eq, S: Shape, A: Allocator> Eq for Tensor<T, S, A> {}
53impl<T: Eq, S: Shape, L: Layout> Eq for View<'_, T, S, L> {}
54impl<T: Eq, S: Shape, L: Layout> Eq for ViewMut<'_, T, S, L> {}
55
56impl<T, U, S: ConstShape, R: Shape, L: Layout, I: ?Sized> PartialEq<I> for Array<T, S>
57where
58 for<'a> &'a I: IntoExpression<IntoExpr = View<'a, U, R, L>>,
59 T: PartialEq<U>,
60{
61 #[inline]
62 fn eq(&self, other: &I) -> bool {
63 (**self).eq(other)
64 }
65}
66
/// Element-wise equality between a slice and anything whose borrow converts
/// into a `View`.
///
/// Returns `false` as soon as the dimension lists differ. Otherwise the
/// elements are compared, using a flat comparison when both layouts are dense
/// and a per-rank comparison otherwise.
impl<T, U, S: Shape, R: Shape, L: Layout, K: Layout, I: ?Sized> PartialEq<I> for Slice<T, S, L>
where
    for<'a> &'a I: IntoExpression<IntoExpr = View<'a, U, R, K>>,
    T: PartialEq<U>,
{
    #[inline]
    fn eq(&self, other: &I) -> bool {
        // Borrow `other` as a view so both sides can be treated as slices.
        let other = other.into_expr();

        // Shapes must agree dimension-by-dimension before any element is read.
        if self.shape().with_dims(|dims| other.shape().with_dims(|other| dims == other)) {
            // Both layouts are dense: remap each side and compare the full
            // element ranges (`[..]`) in one flat comparison.
            #[inline]
            fn compare_dense<T, U, S: Shape, R: Shape, L: Layout, K: Layout>(
                this: &Slice<T, S, L>,
                other: &Slice<U, R, K>,
            ) -> bool
            where
                T: PartialEq<U>,
            {
                this.remap::<S, _>()[..].eq(&other.remap::<R, _>()[..])
            }

            // At least one layout is not dense. Rank 0 and 1 are compared
            // element by element. Higher ranks are compared one outer
            // sub-expression at a time, with each pair checked via
            // `PartialEq` again (presumably recursing into these impls at a
            // lower rank).
            #[inline]
            fn compare_strided<T, U, S: Shape, R: Shape, L: Layout, K: Layout>(
                this: &Slice<T, S, L>,
                other: &Slice<U, R, K>,
            ) -> bool
            where
                T: PartialEq<U>,
            {
                if this.rank() < 2 {
                    this.iter().eq(other)
                } else {
                    this.outer_expr().into_iter().eq(other.outer_expr())
                }
            }

            // The strategy is picked inside an inline `const` block from the
            // layouts' `IS_DENSE` constants, so the choice is fixed at compile
            // time. The comparisons live in separate nested fn items rather
            // than an ordinary `if`.
            // NOTE(review): this looks like a workaround to keep the unused
            // (possibly recursive) path from being instantiated, e.g. for
            // compile time — confirm before restructuring.
            let f =
                const { if L::IS_DENSE && K::IS_DENSE { compare_dense } else { compare_strided } };

            f(self, &other)
        } else {
            false
        }
    }
}
117
118impl<T, U, S: Shape, R: Shape, L: Layout, A: Allocator, I: ?Sized> PartialEq<I> for Tensor<T, S, A>
119where
120 for<'a> &'a I: IntoExpression<IntoExpr = View<'a, U, R, L>>,
121 T: PartialEq<U>,
122{
123 #[inline]
124 fn eq(&self, other: &I) -> bool {
125 (**self).eq(other)
126 }
127}
128
129impl<T, U, S: Shape, R: Shape, L: Layout, K: Layout, I: ?Sized> PartialEq<I> for View<'_, T, S, L>
130where
131 for<'a> &'a I: IntoExpression<IntoExpr = View<'a, U, R, K>>,
132 T: PartialEq<U>,
133{
134 #[inline]
135 fn eq(&self, other: &I) -> bool {
136 (**self).eq(other)
137 }
138}
139
140impl<T, U, S: Shape, R: Shape, L: Layout, K: Layout, I: ?Sized> PartialEq<I>
141 for ViewMut<'_, T, S, L>
142where
143 for<'a> &'a I: IntoExpression<IntoExpr = View<'a, U, R, K>>,
144 T: PartialEq<U>,
145{
146 #[inline]
147 fn eq(&self, other: &I) -> bool {
148 (**self).eq(other)
149 }
150}
151
// Implements the binary operator trait `$trt` (method `$fn`) for the crate's
// container and expression types.
//
// Two families of impls are generated:
//
// * Borrowed containers (`&Array`, `&Slice`, `&Tensor`, `&View`, `&ViewMut`)
//   and the by-value expression types (`Fill`, `FillWith`, `FromElem`,
//   `FromFn`, `IntoExpr`, `Map`, `View`) call `rhs.zip_with(self, ..)`. The
//   result type is therefore chosen by the right-hand side's `Apply` impl
//   (`I::ZippedWith<..>`). Because `zip_with` is invoked on `rhs`, the
//   closure receives `(rhs_item, lhs_item)`. It swaps them back with
//   `y.$fn(x)`, so the operator is still evaluated as `lhs $op rhs`.
// * Owned `Array` and `Tensor` call `self.zip_with(rhs, ..)`. `Array` returns
//   `Array<U, S>`. `Tensor` returns `Self`, which is why its bound requires
//   `Output = T`.
//
// Without the `nightly` feature, the `Output` associated type must be
// nameable, so the closure type is spelled as a `fn` pointer. The closures
// below capture nothing and coerce to it. With `nightly`, an `impl FnMut`
// type is used in associated-type position instead.
macro_rules! impl_binary_op {
    ($trt:tt, $fn:tt) => {
        // Borrowed containers: elements are taken by reference (`&'a T`).
        impl<'a, T, U, S: ConstShape, I: Apply<U>> $trt<I> for &'a Array<T, S>
        where
            &'a T: $trt<I::Item, Output = U>,
        {
            #[cfg(not(feature = "nightly"))]
            type Output = I::ZippedWith<Self, fn((I::Item, &'a T)) -> U>;

            #[cfg(feature = "nightly")]
            type Output = I::ZippedWith<Self, impl FnMut((I::Item, &'a T)) -> U>;

            #[inline]
            fn $fn(self, rhs: I) -> Self::Output {
                rhs.zip_with(self, |(x, y)| y.$fn(x))
            }
        }

        impl<'a, T, U, S: Shape, L: Layout, I: Apply<U>> $trt<I> for &'a Slice<T, S, L>
        where
            &'a T: $trt<I::Item, Output = U>,
        {
            #[cfg(not(feature = "nightly"))]
            type Output = I::ZippedWith<Self, fn((I::Item, &'a T)) -> U>;

            #[cfg(feature = "nightly")]
            type Output = I::ZippedWith<Self, impl FnMut((I::Item, &'a T)) -> U>;

            #[inline]
            fn $fn(self, rhs: I) -> Self::Output {
                rhs.zip_with(self, |(x, y)| y.$fn(x))
            }
        }

        impl<'a, T, U, S: Shape, A: Allocator, I: Apply<U>> $trt<I> for &'a Tensor<T, S, A>
        where
            &'a T: $trt<I::Item, Output = U>,
        {
            #[cfg(not(feature = "nightly"))]
            type Output = I::ZippedWith<Self, fn((I::Item, &'a T)) -> U>;

            #[cfg(feature = "nightly")]
            type Output = I::ZippedWith<Self, impl FnMut((I::Item, &'a T)) -> U>;

            #[inline]
            fn $fn(self, rhs: I) -> Self::Output {
                rhs.zip_with(self, |(x, y)| y.$fn(x))
            }
        }

        impl<'a, T, U, S: Shape, L: Layout, I: Apply<U>> $trt<I> for &'a View<'_, T, S, L>
        where
            &'a T: $trt<I::Item, Output = U>,
        {
            #[cfg(not(feature = "nightly"))]
            type Output = I::ZippedWith<Self, fn((I::Item, &'a T)) -> U>;

            #[cfg(feature = "nightly")]
            type Output = I::ZippedWith<Self, impl FnMut((I::Item, &'a T)) -> U>;

            #[inline]
            fn $fn(self, rhs: I) -> Self::Output {
                rhs.zip_with(self, |(x, y)| y.$fn(x))
            }
        }

        impl<'a, T, U, S: Shape, L: Layout, I: Apply<U>> $trt<I> for &'a ViewMut<'_, T, S, L>
        where
            &'a T: $trt<I::Item, Output = U>,
        {
            #[cfg(not(feature = "nightly"))]
            type Output = I::ZippedWith<Self, fn((I::Item, &'a T)) -> U>;

            #[cfg(feature = "nightly")]
            type Output = I::ZippedWith<Self, impl FnMut((I::Item, &'a T)) -> U>;

            #[inline]
            fn $fn(self, rhs: I) -> Self::Output {
                rhs.zip_with(self, |(x, y)| y.$fn(x))
            }
        }

        // Owned array: zips `self` with `rhs` directly, with operands in
        // natural order, and yields an array of the output element type.
        impl<T, U, S: ConstShape, I: IntoExpression> $trt<I> for Array<T, S>
        where
            T: $trt<I::Item, Output = U>,
        {
            type Output = Array<U, S>;

            #[inline]
            fn $fn(self, rhs: I) -> Self::Output {
                self.zip_with(rhs, |(x, y)| x.$fn(y))
            }
        }

        // By-value expression types: elements are taken by value (`T`).
        impl<T: Clone, U, I: Apply<U>> $trt<I> for Fill<T>
        where
            T: $trt<I::Item, Output = U>,
        {
            #[cfg(not(feature = "nightly"))]
            type Output = I::ZippedWith<Self, fn((I::Item, T)) -> U>;

            #[cfg(feature = "nightly")]
            type Output = I::ZippedWith<Self, impl FnMut((I::Item, T)) -> U>;

            #[inline]
            fn $fn(self, rhs: I) -> Self::Output {
                rhs.zip_with(self, |(x, y)| y.$fn(x))
            }
        }

        impl<T: Clone, U, F: FnMut() -> T, I: Apply<U>> $trt<I> for FillWith<F>
        where
            T: $trt<I::Item, Output = U>,
        {
            #[cfg(not(feature = "nightly"))]
            type Output = I::ZippedWith<Self, fn((I::Item, T)) -> U>;

            #[cfg(feature = "nightly")]
            type Output = I::ZippedWith<Self, impl FnMut((I::Item, T)) -> U>;

            #[inline]
            fn $fn(self, rhs: I) -> Self::Output {
                rhs.zip_with(self, |(x, y)| y.$fn(x))
            }
        }

        impl<S: Shape, T: Clone, U, I: Apply<U>> $trt<I> for FromElem<T, S>
        where
            T: $trt<I::Item, Output = U>,
        {
            #[cfg(not(feature = "nightly"))]
            type Output = I::ZippedWith<Self, fn((I::Item, T)) -> U>;

            #[cfg(feature = "nightly")]
            type Output = I::ZippedWith<Self, impl FnMut((I::Item, T)) -> U>;

            #[inline]
            fn $fn(self, rhs: I) -> Self::Output {
                rhs.zip_with(self, |(x, y)| y.$fn(x))
            }
        }

        impl<S: Shape, T, U, F: FnMut(&[usize]) -> T, I: Apply<U>> $trt<I> for FromFn<S, F>
        where
            T: $trt<I::Item, Output = U>,
        {
            #[cfg(not(feature = "nightly"))]
            type Output = I::ZippedWith<Self, fn((I::Item, T)) -> U>;

            #[cfg(feature = "nightly")]
            type Output = I::ZippedWith<Self, impl FnMut((I::Item, T)) -> U>;

            #[inline]
            fn $fn(self, rhs: I) -> Self::Output {
                rhs.zip_with(self, |(x, y)| y.$fn(x))
            }
        }

        // `IntoExpr` yields the buffer's item type, so the operator is
        // required on `B::Item` rather than on a separate `T`.
        impl<T, B: Buffer, I: Apply<T>> $trt<I> for IntoExpr<B>
        where
            B::Item: $trt<I::Item, Output = T>,
        {
            #[cfg(not(feature = "nightly"))]
            type Output = I::ZippedWith<Self, fn((I::Item, B::Item)) -> T>;

            #[cfg(feature = "nightly")]
            type Output = I::ZippedWith<Self, impl FnMut((I::Item, B::Item)) -> T>;

            #[inline]
            fn $fn(self, rhs: I) -> Self::Output {
                rhs.zip_with(self, |(x, y)| y.$fn(x))
            }
        }

        impl<T, U, E: Expression, F: FnMut(E::Item) -> T, I: Apply<U>> $trt<I> for Map<E, F>
        where
            T: $trt<I::Item, Output = U>,
        {
            #[cfg(not(feature = "nightly"))]
            type Output = I::ZippedWith<Self, fn((I::Item, T)) -> U>;

            #[cfg(feature = "nightly")]
            type Output = I::ZippedWith<Self, impl FnMut((I::Item, T)) -> U>;

            #[inline]
            fn $fn(self, rhs: I) -> Self::Output {
                rhs.zip_with(self, |(x, y)| y.$fn(x))
            }
        }

        // Owned tensor: the result type is `Self`, so the operator must keep
        // the element type (`Output = T`).
        // NOTE(review): presumably this lets `Tensor::zip_with` reuse the
        // existing allocation — confirm.
        impl<T, S: Shape, A: Allocator, I: IntoExpression> $trt<I> for Tensor<T, S, A>
        where
            T: $trt<I::Item, Output = T>,
        {
            type Output = Self;

            #[inline]
            fn $fn(self, rhs: I) -> Self {
                self.zip_with(rhs, |(x, y)| x.$fn(y))
            }
        }

        // A `View` taken by value still yields borrowed elements (`&'a T`).
        impl<'a, T, U, S: Shape, L: Layout, I: Apply<U>> $trt<I> for View<'a, T, S, L>
        where
            &'a T: $trt<I::Item, Output = U>,
        {
            #[cfg(not(feature = "nightly"))]
            type Output = I::ZippedWith<Self, fn((I::Item, &'a T)) -> U>;

            #[cfg(feature = "nightly")]
            type Output = I::ZippedWith<Self, impl FnMut((I::Item, &'a T)) -> U>;

            #[inline]
            fn $fn(self, rhs: I) -> Self::Output {
                rhs.zip_with(self, |(x, y)| y.$fn(x))
            }
        }
    };
}

// Element-wise arithmetic, remainder, bitwise and shift operators.
impl_binary_op!(Add, add);
impl_binary_op!(Sub, sub);
impl_binary_op!(Mul, mul);
impl_binary_op!(Div, div);
impl_binary_op!(Rem, rem);
impl_binary_op!(BitAnd, bitand);
impl_binary_op!(BitOr, bitor);
impl_binary_op!(BitXor, bitxor);
impl_binary_op!(Shl, shl);
impl_binary_op!(Shr, shr);
382
// Implements the compound-assignment operator trait `$trt` (method `$fn`) for
// the mutable container types `Array`, `Slice`, `Tensor` and `ViewMut`.
//
// Each impl zips a mutable expression over `self` with `rhs` and applies
// `$fn` to every `(&mut element, rhs_item)` pair in place. The right-hand
// side may be anything implementing `IntoExpression`.
// NOTE(review): how a shape mismatch between `self` and `rhs` is handled is
// decided by `zip` (presumably a panic) — confirm in the expression module.
macro_rules! impl_op_assign {
    ($trt:tt, $fn:tt) => {
        impl<T, S: ConstShape, I: IntoExpression> $trt<I> for Array<T, S>
        where
            T: $trt<I::Item>,
        {
            #[inline]
            fn $fn(&mut self, rhs: I) {
                self.expr_mut().zip(rhs).for_each(|(x, y)| x.$fn(y));
            }
        }

        impl<T, S: Shape, L: Layout, I: IntoExpression> $trt<I> for Slice<T, S, L>
        where
            T: $trt<I::Item>,
        {
            #[inline]
            fn $fn(&mut self, rhs: I) {
                self.expr_mut().zip(rhs).for_each(|(x, y)| x.$fn(y));
            }
        }

        impl<T, S: Shape, A: Allocator, I: IntoExpression> $trt<I> for Tensor<T, S, A>
        where
            T: $trt<I::Item>,
        {
            #[inline]
            fn $fn(&mut self, rhs: I) {
                self.expr_mut().zip(rhs).for_each(|(x, y)| x.$fn(y));
            }
        }

        impl<T, S: Shape, L: Layout, I: IntoExpression> $trt<I> for ViewMut<'_, T, S, L>
        where
            T: $trt<I::Item>,
        {
            #[inline]
            fn $fn(&mut self, rhs: I) {
                self.expr_mut().zip(rhs).for_each(|(x, y)| x.$fn(y));
            }
        }
    };
}

// In-place element-wise arithmetic, remainder, bitwise and shift operators.
impl_op_assign!(AddAssign, add_assign);
impl_op_assign!(SubAssign, sub_assign);
impl_op_assign!(MulAssign, mul_assign);
impl_op_assign!(DivAssign, div_assign);
impl_op_assign!(RemAssign, rem_assign);
impl_op_assign!(BitAndAssign, bitand_assign);
impl_op_assign!(BitOrAssign, bitor_assign);
impl_op_assign!(BitXorAssign, bitxor_assign);
impl_op_assign!(ShlAssign, shl_assign);
impl_op_assign!(ShrAssign, shr_assign);
437
// Implements the unary operator trait `$trt` (method `$fn`) by applying it to
// every element through `apply`.
//
// * Borrowed containers (`&Array`, `&Slice`, `&Tensor`, `&View`, `&ViewMut`)
//   and by-value `View` require the operator on `&T`. The other by-value
//   expression types require it on the element type they yield.
// * Owned `Array` returns `Array<U, S>`. Owned `Tensor` returns `Self`, which
//   forces the operator to keep the element type (`Output = T`).
// * All remaining impls return `<Self as Apply<U>>::Output<..>`, so the
//   concrete result type is chosen by that `Apply` impl.
//
// Without `nightly`, the closure type inside `Output` is spelled as a `fn`
// pointer, which the non-capturing closures coerce to. With `nightly`, an
// `impl FnMut` associated type is used instead.
macro_rules! impl_unary_op {
    ($trt:tt, $fn:tt) => {
        // Borrowed containers: elements are taken by reference (`&'a T`).
        impl<'a, T, U, S: ConstShape> $trt for &'a Array<T, S>
        where
            &'a T: $trt<Output = U>,
        {
            #[cfg(not(feature = "nightly"))]
            type Output = <Self as Apply<U>>::Output<fn(&'a T) -> U>;

            #[cfg(feature = "nightly")]
            type Output = <Self as Apply<U>>::Output<impl FnMut(&'a T) -> U>;

            #[inline]
            fn $fn(self) -> Self::Output {
                self.apply(|x| x.$fn())
            }
        }

        impl<'a, T, U, S: Shape, L: Layout> $trt for &'a Slice<T, S, L>
        where
            &'a T: $trt<Output = U>,
        {
            #[cfg(not(feature = "nightly"))]
            type Output = <Self as Apply<U>>::Output<fn(&'a T) -> U>;

            #[cfg(feature = "nightly")]
            type Output = <Self as Apply<U>>::Output<impl FnMut(&'a T) -> U>;

            #[inline]
            fn $fn(self) -> Self::Output {
                self.apply(|x| x.$fn())
            }
        }

        impl<'a, T, U, S: Shape, A: Allocator> $trt for &'a Tensor<T, S, A>
        where
            &'a T: $trt<Output = U>,
        {
            #[cfg(not(feature = "nightly"))]
            type Output = <Self as Apply<U>>::Output<fn(&'a T) -> U>;

            #[cfg(feature = "nightly")]
            type Output = <Self as Apply<U>>::Output<impl FnMut(&'a T) -> U>;

            #[inline]
            fn $fn(self) -> Self::Output {
                self.apply(|x| x.$fn())
            }
        }

        impl<'a, T, U, S: Shape, L: Layout> $trt for &'a View<'_, T, S, L>
        where
            &'a T: $trt<Output = U>,
        {
            #[cfg(not(feature = "nightly"))]
            type Output = <Self as Apply<U>>::Output<fn(&'a T) -> U>;

            #[cfg(feature = "nightly")]
            type Output = <Self as Apply<U>>::Output<impl FnMut(&'a T) -> U>;

            #[inline]
            fn $fn(self) -> Self::Output {
                self.apply(|x| x.$fn())
            }
        }

        impl<'a, T, U, S: Shape, L: Layout> $trt for &'a ViewMut<'_, T, S, L>
        where
            &'a T: $trt<Output = U>,
        {
            #[cfg(not(feature = "nightly"))]
            type Output = <Self as Apply<U>>::Output<fn(&'a T) -> U>;

            #[cfg(feature = "nightly")]
            type Output = <Self as Apply<U>>::Output<impl FnMut(&'a T) -> U>;

            #[inline]
            fn $fn(self) -> Self::Output {
                self.apply(|x| x.$fn())
            }
        }

        // Owned array: produces an array of the output element type.
        impl<T, U, S: ConstShape> $trt for Array<T, S>
        where
            T: $trt<Output = U>,
        {
            type Output = Array<U, S>;

            #[inline]
            fn $fn(self) -> Self::Output {
                self.apply(|x| x.$fn())
            }
        }

        // By-value expression types: elements are taken by value (`T`).
        impl<T: Clone, U> $trt for Fill<T>
        where
            T: $trt<Output = U>,
        {
            #[cfg(not(feature = "nightly"))]
            type Output = <Self as Apply<U>>::Output<fn(T) -> U>;

            #[cfg(feature = "nightly")]
            type Output = <Self as Apply<U>>::Output<impl FnMut(T) -> U>;

            #[inline]
            fn $fn(self) -> Self::Output {
                self.apply(|x| x.$fn())
            }
        }

        impl<T: Clone, U, F: FnMut() -> T> $trt for FillWith<F>
        where
            T: $trt<Output = U>,
        {
            #[cfg(not(feature = "nightly"))]
            type Output = <Self as Apply<U>>::Output<fn(T) -> U>;

            #[cfg(feature = "nightly")]
            type Output = <Self as Apply<U>>::Output<impl FnMut(T) -> U>;

            #[inline]
            fn $fn(self) -> Self::Output {
                self.apply(|x| x.$fn())
            }
        }

        impl<S: Shape, T: Clone, U> $trt for FromElem<T, S>
        where
            T: $trt<Output = U>,
        {
            #[cfg(not(feature = "nightly"))]
            type Output = <Self as Apply<U>>::Output<fn(T) -> U>;

            #[cfg(feature = "nightly")]
            type Output = <Self as Apply<U>>::Output<impl FnMut(T) -> U>;

            #[inline]
            fn $fn(self) -> Self::Output {
                self.apply(|x| x.$fn())
            }
        }

        impl<S: Shape, T, U, F: FnMut(&[usize]) -> T> $trt for FromFn<S, F>
        where
            T: $trt<Output = U>,
        {
            #[cfg(not(feature = "nightly"))]
            type Output = <Self as Apply<U>>::Output<fn(T) -> U>;

            #[cfg(feature = "nightly")]
            type Output = <Self as Apply<U>>::Output<impl FnMut(T) -> U>;

            #[inline]
            fn $fn(self) -> Self::Output {
                self.apply(|x| x.$fn())
            }
        }

        // `IntoExpr` yields the buffer's item type, so the operator is
        // required on `B::Item`.
        impl<T, B: Buffer> $trt for IntoExpr<B>
        where
            B::Item: $trt<Output = T>,
        {
            #[cfg(not(feature = "nightly"))]
            type Output = <Self as Apply<T>>::Output<fn(B::Item) -> T>;

            #[cfg(feature = "nightly")]
            type Output = <Self as Apply<T>>::Output<impl FnMut(B::Item) -> T>;

            #[inline]
            fn $fn(self) -> Self::Output {
                self.apply(|x| x.$fn())
            }
        }

        impl<T, U, E: Expression, F: FnMut(E::Item) -> T> $trt for Map<E, F>
        where
            T: $trt<Output = U>,
        {
            #[cfg(not(feature = "nightly"))]
            type Output = <Self as Apply<U>>::Output<fn(T) -> U>;

            #[cfg(feature = "nightly")]
            type Output = <Self as Apply<U>>::Output<impl FnMut(T) -> U>;

            #[inline]
            fn $fn(self) -> Self::Output {
                self.apply(|x| x.$fn())
            }
        }

        // Owned tensor: the result type is `Self`, so the element type must
        // be preserved (`Output = T`).
        impl<T, S: Shape, A: Allocator> $trt for Tensor<T, S, A>
        where
            T: $trt<Output = T>,
        {
            type Output = Self;

            #[inline]
            fn $fn(self) -> Self {
                self.apply(|x| x.$fn())
            }
        }

        // A `View` taken by value still yields borrowed elements (`&'a T`).
        impl<'a, T, U, S: Shape, L: Layout> $trt for View<'a, T, S, L>
        where
            &'a T: $trt<Output = U>,
        {
            #[cfg(not(feature = "nightly"))]
            type Output = <Self as Apply<U>>::Output<fn(&'a T) -> U>;

            #[cfg(feature = "nightly")]
            type Output = <Self as Apply<U>>::Output<impl FnMut(&'a T) -> U>;

            #[inline]
            fn $fn(self) -> Self::Output {
                self.apply(|x| x.$fn())
            }
        }
    };
}

// Element-wise negation and logical/bitwise complement.
impl_unary_op!(Neg, neg);
impl_unary_op!(Not, not);