1use crate::prelude_dev::*;
2use core::mem::transmute;
3
// Generates one `Tensor*API` trait per binary operation (arithmetic, bitwise
// and shift). For each table row, `op`/`op_f` are the generated method names
// and `TensorOpAPI` the generated trait name (e.g. `add`/`add_f` on
// `TensorAddAPI`).
#[duplicate_item(
    op       op_f       TensorOpAPI      ;
    [add   ] [add_f   ] [TensorAddAPI   ];
    [sub   ] [sub_f   ] [TensorSubAPI   ];
    [mul   ] [mul_f   ] [TensorMulAPI   ];
    [div   ] [div_f   ] [TensorDivAPI   ];
    [rem   ] [rem_f   ] [TensorRemAPI   ];
    [bitor ] [bitor_f ] [TensorBitOrAPI ];
    [bitand] [bitand_f] [TensorBitAndAPI];
    [bitxor] [bitxor_f] [TensorBitXorAPI];
    [shl   ] [shl_f   ] [TensorShlAPI   ];
    [shr   ] [shr_f   ] [TensorShrAPI   ];
)]
pub trait TensorOpAPI<TrB> {
    // Result type of the binary operation `a (op) b`.
    type Output;
    // Fallible form of the operation; implementations in this file return
    // `Err` on device mismatch or broadcast/layout failure.
    fn op_f(a: Self, b: TrB) -> Result<Self::Output>;
    // Panicking convenience wrapper: delegates to the fallible form and
    // unwraps via the crate's error reporter.
    fn op(a: Self, b: TrB) -> Self::Output
    where
        Self: Sized,
    {
        Self::op_f(a, b).rstsr_unwrap()
    }
}
29
30#[duplicate_item(
31 op op_f TensorOpAPI ;
32 [add ] [add_f ] [TensorAddAPI ];
33 [sub ] [sub_f ] [TensorSubAPI ];
34 [mul ] [mul_f ] [TensorMulAPI ];
35 [div ] [div_f ] [TensorDivAPI ];
36 [rem ] [rem_f ] [TensorRemAPI ];
37 [bitor ] [bitor_f ] [TensorBitOrAPI ];
38 [bitand] [bitand_f] [TensorBitAndAPI];
39 [bitxor] [bitxor_f] [TensorBitXorAPI];
40 [shl ] [shl_f ] [TensorShlAPI ];
41 [shr ] [shr_f ] [TensorShrAPI ];
42)]
43pub fn op_f<TrA, TrB>(a: TrA, b: TrB) -> Result<TrA::Output>
44where
45 TrA: TensorOpAPI<TrB>,
46{
47 TrA::op_f(a, b)
48}
49
50#[duplicate_item(
51 op op_f TensorOpAPI ;
52 [add ] [add_f ] [TensorAddAPI ];
53 [sub ] [sub_f ] [TensorSubAPI ];
54 [mul ] [mul_f ] [TensorMulAPI ];
55 [div ] [div_f ] [TensorDivAPI ];
56 [rem ] [rem_f ] [TensorRemAPI ];
57 [bitor ] [bitor_f ] [TensorBitOrAPI ];
58 [bitand] [bitand_f] [TensorBitAndAPI];
59 [bitxor] [bitxor_f] [TensorBitXorAPI];
60 [shl ] [shl_f ] [TensorShlAPI ];
61 [shr ] [shr_f ] [TensorShrAPI ];
62)]
63pub fn op<TrA, TrB>(a: TrA, b: TrB) -> TrA::Output
64where
65 TrA: TensorOpAPI<TrB>,
66{
67 TrA::op(a, b)
68}
69
// Inherent `op_f`/`op` methods on `TensorBase`, generated once per binary
// operation. They take `self` by reference and forward to whichever
// `Tensor*API` implementation exists for `&Self`.
#[duplicate_item(
    op       op_f       TensorOpAPI      ;
    [add   ] [add_f   ] [TensorAddAPI   ];
    [sub   ] [sub_f   ] [TensorSubAPI   ];
    [mul   ] [mul_f   ] [TensorMulAPI   ];
    [div   ] [div_f   ] [TensorDivAPI   ];
    [rem   ] [rem_f   ] [TensorRemAPI   ];
    [bitor ] [bitor_f ] [TensorBitOrAPI ];
    [bitand] [bitand_f] [TensorBitAndAPI];
    [bitxor] [bitxor_f] [TensorBitXorAPI];
    [shl   ] [shl_f   ] [TensorShlAPI   ];
    [shr   ] [shr_f   ] [TensorShrAPI   ];
)]
impl<S, D> TensorBase<S, D>
where
    D: DimAPI,
{
    // Fallible binary operation on `&self`; the HRTB bound requires the
    // trait impl for any borrow lifetime of `Self`.
    pub fn op_f<TrB>(&self, b: TrB) -> Result<<&Self as TensorOpAPI<TrB>>::Output>
    where
        for<'a> &'a Self: TensorOpAPI<TrB>,
    {
        <&Self as TensorOpAPI<TrB>>::op_f(self, b)
    }

    // Panicking counterpart of the method above.
    pub fn op<TrB>(&self, b: TrB) -> <&Self as TensorOpAPI<TrB>>::Output
    where
        for<'a> &'a Self: TensorOpAPI<TrB>,
    {
        <&Self as TensorOpAPI<TrB>>::op(self, b)
    }
}
101
102#[duplicate_item(
107 op DeviceOpAPI TensorOpAPI Op ;
108 [add ] [DeviceAddAPI ] [TensorAddAPI ] [Add ];
109 [sub ] [DeviceSubAPI ] [TensorSubAPI ] [Sub ];
110 [mul ] [DeviceMulAPI ] [TensorMulAPI ] [Mul ];
111 [div ] [DeviceDivAPI ] [TensorDivAPI ] [Div ];
112[bitor ] [DeviceBitOrAPI ] [TensorBitOrAPI ] [BitOr ];
114 [bitand] [DeviceBitAndAPI] [TensorBitAndAPI] [BitAnd];
115 [bitxor] [DeviceBitXorAPI] [TensorBitXorAPI] [BitXor];
116 [shl ] [DeviceShlAPI ] [TensorShlAPI ] [Shl ];
117 [shr ] [DeviceShrAPI ] [TensorShrAPI ] [Shr ];
118)]
119mod impl_core_ops {
120 use super::*;
121
122 impl<SA, DA, TrB> Op<TrB> for &TensorBase<SA, DA>
123 where
124 DA: DimAPI,
125 Self: TensorOpAPI<TrB>,
126 {
127 type Output = <Self as TensorOpAPI<TrB>>::Output;
128 fn op(self, b: TrB) -> Self::Output {
129 TensorOpAPI::op(self, b)
130 }
131 }
132
133 #[duplicate_item(
134 TrA; [TensorView<'_, TA, B, DA>]; [Tensor<TA, B, DA>]; [TensorCow<'_, TA, B, DA>];
135 )]
136 impl<TA, DA, B, TrB> Op<TrB> for TrA
137 where
138 DA: DimAPI,
139 B: DeviceAPI<TA>,
140 Self: TensorOpAPI<TrB>,
141 {
142 type Output = <Self as TensorOpAPI<TrB>>::Output;
143 fn op(self, b: TrB) -> Self::Output {
144 TensorOpAPI::op(self, b)
145 }
146 }
147}
148
// By-reference implementations: `&a (op) &b` allocates a fresh output tensor.
// View/ref mixed combinations below delegate to this ref-ref kernel.
#[duplicate_item(
    op_f       DeviceOpAPI       TensorOpAPI       Op      ;
    [add_f   ] [DeviceAddAPI   ] [TensorAddAPI   ] [Add   ];
    [sub_f   ] [DeviceSubAPI   ] [TensorSubAPI   ] [Sub   ];
    [mul_f   ] [DeviceMulAPI   ] [TensorMulAPI   ] [Mul   ];
    [div_f   ] [DeviceDivAPI   ] [TensorDivAPI   ] [Div   ];
    [rem_f   ] [DeviceRemAPI   ] [TensorRemAPI   ] [Rem   ];
    [bitor_f ] [DeviceBitOrAPI ] [TensorBitOrAPI ] [BitOr ];
    [bitand_f] [DeviceBitAndAPI] [TensorBitAndAPI] [BitAnd];
    [bitxor_f] [DeviceBitXorAPI] [TensorBitXorAPI] [BitXor];
    [shl_f   ] [DeviceShlAPI   ] [TensorShlAPI   ] [Shl   ];
    [shr_f   ] [DeviceShrAPI   ] [TensorShrAPI   ] [Shr   ];
)]
mod impl_binary_arithmetic_ref {
    use super::*;

    // Core kernel: `&TensorAny (op) &TensorAny` with independent element
    // types TA/TB producing TC; output is a freshly allocated `Tensor`.
    #[doc(hidden)]
    impl<RA, RB, TA, TB, TC, DA, DB, DC, B> TensorOpAPI<&TensorAny<RB, TB, B, DB>> for &TensorAny<RA, TA, B, DA>
    where
        RA: DataAPI<Data = <B as DeviceRawAPI<TA>>::Raw>,
        RB: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>,
        DA: DimAPI,
        DB: DimAPI,
        DC: DimAPI,
        B: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<TC> + DeviceRawAPI<MaybeUninit<TC>>,
        B: DeviceCreationAnyAPI<TC>,
        DA: DimMaxAPI<DB, Max = DC>,
        TA: Op<TB, Output = TC>,
        B: DeviceOpAPI<TA, TB, TC, DC>,
    {
        type Output = Tensor<TC, B, DC>;
        fn op_f(a: Self, b: &TensorAny<RB, TB, B, DB>) -> Result<Self::Output> {
            let a = a.view();
            let b = b.view();
            // Both operands must live on the same device.
            rstsr_assert!(a.device().same_device(b.device()), DeviceMismatch)?;
            let la = a.layout();
            let lb = b.layout();
            let default_order = a.device().default_order();
            // Broadcast the two layouts against each other.
            let (la_b, lb_b) = broadcast_layout(la, lb, default_order)?;
            // Choose the output layout: prefer the shared copy layout when
            // both broadcast operands agree; otherwise fall back to the
            // device's default order (C for row-major, F for col-major).
            let lc_from_a = layout_for_array_copy(&la_b, TensorIterOrder::default())?;
            let lc_from_b = layout_for_array_copy(&lb_b, TensorIterOrder::default())?;
            let lc = if lc_from_a == lc_from_b {
                lc_from_a
            } else {
                match a.device().default_order() {
                    RowMajor => la_b.shape().c(),
                    ColMajor => la_b.shape().f(),
                }
            };
            // Allocate uninitialized output storage, run the device kernel,
            // then mark the storage initialized.
            let device = a.device();
            let mut storage_c = device.uninit_impl(lc.bounds_index()?.1)?;
            device.op_mutc_refa_refb(storage_c.raw_mut(), &lc, a.raw(), &la_b, b.raw(), &lb_b)?;
            let storage_c = unsafe { B::assume_init_impl(storage_c) }?;
            Tensor::new_f(storage_c, lc)
        }
    }

    // Mixed ref/view combinations: re-borrow as needed and delegate to the
    // ref-ref kernel above.
    #[doc(hidden)]
    #[duplicate_item(
        RType                                              TrA                         TrB                         a_inner b_inner ;
        [R: DataAPI<Data = <B as DeviceRawAPI<TA>>::Raw>] [&TensorAny<R, TA, B, DA> ] [TensorView<'_, TB, B, DB>] [ a    ] [&b    ];
        [R: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>] [TensorView<'_, TA, B, DA>] [&TensorAny<R, TB, B, DB> ] [&a    ] [ b    ];
        [                                               ] [TensorView<'_, TA, B, DA>] [TensorView<'_, TB, B, DB>] [&a    ] [&b    ];
    )]
    impl<TA, TB, TC, DA, DB, DC, B, RType> TensorOpAPI<TrB> for TrA
    where
        DA: DimAPI,
        DB: DimAPI,
        DC: DimAPI,
        B: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<TC>,
        B: DeviceCreationAnyAPI<TC>,
        DA: DimMaxAPI<DB, Max = DC>,
        TA: Op<TB, Output = TC>,
        B: DeviceOpAPI<TA, TB, TC, DC>,
    {
        type Output = Tensor<TC, B, DC>;
        fn op_f(a: Self, b: TrB) -> Result<Self::Output> {
            TensorOpAPI::op_f(a_inner, b_inner)
        }
    }
}
247
// Consuming implementations: when one operand is an owned `Tensor`, try to
// reuse its buffer for the result (in-place device kernel) instead of
// allocating. Falls back to the by-reference path when the owned operand's
// layout cannot hold the broadcast result.
#[duplicate_item(
    op_f       DeviceOpAPI       TensorOpAPI       Op       DeviceLConsumeAPI        DeviceRConsumeAPI       ;
    [add_f   ] [DeviceAddAPI   ] [TensorAddAPI   ] [Add   ] [DeviceLConsumeAddAPI   ] [DeviceRConsumeAddAPI   ];
    [sub_f   ] [DeviceSubAPI   ] [TensorSubAPI   ] [Sub   ] [DeviceLConsumeSubAPI   ] [DeviceRConsumeSubAPI   ];
    [mul_f   ] [DeviceMulAPI   ] [TensorMulAPI   ] [Mul   ] [DeviceLConsumeMulAPI   ] [DeviceRConsumeMulAPI   ];
    [div_f   ] [DeviceDivAPI   ] [TensorDivAPI   ] [Div   ] [DeviceLConsumeDivAPI   ] [DeviceRConsumeDivAPI   ];
    [rem_f   ] [DeviceRemAPI   ] [TensorRemAPI   ] [Rem   ] [DeviceLConsumeRemAPI   ] [DeviceRConsumeRemAPI   ];
    [bitor_f ] [DeviceBitOrAPI ] [TensorBitOrAPI ] [BitOr ] [DeviceLConsumeBitOrAPI ] [DeviceRConsumeBitOrAPI ];
    [bitand_f] [DeviceBitAndAPI] [TensorBitAndAPI] [BitAnd] [DeviceLConsumeBitAndAPI] [DeviceRConsumeBitAndAPI];
    [bitxor_f] [DeviceBitXorAPI] [TensorBitXorAPI] [BitXor] [DeviceLConsumeBitXorAPI] [DeviceRConsumeBitXorAPI];
    [shl_f   ] [DeviceShlAPI   ] [TensorShlAPI   ] [Shl   ] [DeviceLConsumeShlAPI   ] [DeviceRConsumeShlAPI   ];
    [shr_f   ] [DeviceShrAPI   ] [TensorShrAPI   ] [Shr   ] [DeviceLConsumeShrAPI   ] [DeviceRConsumeShrAPI   ];
)]
mod impl_binary_lr_consume {
    use super::*;

    // Owned lhs (op) borrowed rhs: consume the lhs buffer when possible.
    // Requires `Output = TA` so the result fits in the lhs storage.
    #[doc(hidden)]
    impl<RB, TA, TB, DA, DB, DC, B> TensorOpAPI<&TensorAny<RB, TB, B, DB>> for Tensor<TA, B, DA>
    where
        RB: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>,
        DA: DimAPI,
        DB: DimAPI,
        DC: DimAPI,
        B: DeviceAPI<TA> + DeviceAPI<TB>,
        B: DeviceCreationAnyAPI<TA>,
        DA: DimMaxAPI<DB, Max = DC>,
        DC: DimIntoAPI<DA>,
        DA: DimIntoAPI<DC>,
        TA: Op<TB, Output = TA>,
        B: DeviceOpAPI<TA, TB, TA, DC>,
        B: DeviceLConsumeAPI<TA, TB, DA>,
    {
        type Output = Tensor<TA, B, DC>;
        fn op_f(a: Self, b: &TensorAny<RB, TB, B, DB>) -> Result<Self::Output> {
            rstsr_assert!(a.device().same_device(b.device()), DeviceMismatch)?;
            let device = a.device().clone();
            let la = a.layout();
            let lb = b.layout();
            let default_order = a.device().default_order();
            // Try broadcasting rhs onto the lhs layout.
            let broadcast_result = broadcast_layout_to_first(la, lb, default_order);
            if a.layout().is_broadcasted() || broadcast_result.is_err() {
                // Lhs buffer cannot be reused; fall back to the allocating
                // by-reference implementation.
                TensorOpAPI::op_f(&a, b)
            } else {
                let (la_b, lb_b) = broadcast_result?;
                if la_b != *la {
                    // Broadcasting changed the lhs layout; cannot consume.
                    TensorOpAPI::op_f(&a, b)
                } else {
                    // In-place: run the kernel directly on the lhs buffer.
                    let (mut storage_a, _) = a.into_raw_parts();
                    device.op_muta_refb(storage_a.raw_mut(), &la_b, b.raw(), &lb_b)?;
                    let c = unsafe { Tensor::new_unchecked(storage_a, la_b) };
                    c.into_dim_f::<DC>()
                }
            }
        }
    }

    // Borrowed lhs (op) owned rhs: mirror case, consuming the rhs buffer.
    // Requires `Output = TB` so the result fits in the rhs storage.
    #[doc(hidden)]
    impl<RA, TA, TB, DA, DB, DC, B> TensorOpAPI<Tensor<TB, B, DB>> for &TensorAny<RA, TA, B, DA>
    where
        RA: DataAPI<Data = <B as DeviceRawAPI<TA>>::Raw>,
        DA: DimAPI,
        DB: DimAPI,
        DC: DimAPI,
        B: DeviceAPI<TA> + DeviceAPI<TB>,
        B: DeviceCreationAnyAPI<TB>,
        DA: DimMaxAPI<DB, Max = DC>,
        DB: DimMaxAPI<DA, Max = DC>,
        DC: DimIntoAPI<DB>,
        DB: DimIntoAPI<DC>,
        TA: Op<TB, Output = TB>,
        B: DeviceOpAPI<TA, TB, TB, DC>,
        B: DeviceRConsumeAPI<TA, TB, DB>,
    {
        type Output = Tensor<TB, B, DC>;
        fn op_f(a: Self, b: Tensor<TB, B, DB>) -> Result<Self::Output> {
            rstsr_assert!(a.device().same_device(b.device()), DeviceMismatch)?;
            let device = b.device().clone();
            let la = a.layout();
            let lb = b.layout();
            let default_order = b.device().default_order();
            // Try broadcasting lhs onto the rhs layout.
            let broadcast_result = broadcast_layout_to_first(lb, la, default_order);
            if b.layout().is_broadcasted() || broadcast_result.is_err() {
                // Rhs buffer cannot be reused; fall back to by-reference.
                TensorOpAPI::op_f(a, &b)
            } else {
                let (lb_b, la_b) = broadcast_result?;
                if lb_b != *lb {
                    // Broadcasting changed the rhs layout; cannot consume.
                    TensorOpAPI::op_f(a, &b)
                } else {
                    // In-place: run the kernel directly on the rhs buffer.
                    let (mut storage_b, _) = b.into_raw_parts();
                    device.op_muta_refb(storage_b.raw_mut(), &lb_b, a.raw(), &la_b)?;
                    let c = unsafe { Tensor::new_unchecked(storage_b, lb_b) };
                    c.into_dim_f::<DC>()
                }
            }
        }
    }

    // Owned lhs (op) view rhs: borrow the view and reuse the lhs-consuming impl.
    #[doc(hidden)]
    impl<'b, TA, TB, DA, DB, DC, B> TensorOpAPI<TensorView<'b, TB, B, DB>> for Tensor<TA, B, DA>
    where
        DA: DimAPI,
        DB: DimAPI,
        DC: DimAPI,
        B: DeviceAPI<TA> + DeviceAPI<TB>,
        B: DeviceCreationAnyAPI<TA>,
        DA: DimMaxAPI<DB, Max = DC>,
        DC: DimIntoAPI<DA>,
        DA: DimIntoAPI<DC>,
        TA: Op<TB, Output = TA>,
        B: DeviceOpAPI<TA, TB, TA, DC>,
        B: DeviceLConsumeAPI<TA, TB, DA>,
    {
        type Output = Tensor<TA, B, DC>;
        fn op_f(a: Self, b: TensorView<'b, TB, B, DB>) -> Result<Self::Output> {
            TensorOpAPI::op_f(a, &b)
        }
    }

    // View lhs (op) owned rhs: borrow the view and reuse the rhs-consuming impl.
    #[doc(hidden)]
    impl<TA, TB, DA, DB, DC, B> TensorOpAPI<Tensor<TB, B, DB>> for TensorView<'_, TA, B, DA>
    where
        DA: DimAPI,
        DB: DimAPI,
        DC: DimAPI,
        B: DeviceAPI<TA> + DeviceAPI<TB>,
        B: DeviceCreationAnyAPI<TB>,
        DA: DimMaxAPI<DB, Max = DC>,
        DB: DimMaxAPI<DA, Max = DC>,
        DC: DimIntoAPI<DB>,
        DB: DimIntoAPI<DC>,
        TA: Op<TB, Output = TB>,
        B: DeviceOpAPI<TA, TB, TB, DC>,
        B: DeviceRConsumeAPI<TA, TB, DB>,
    {
        type Output = Tensor<TB, B, DC>;
        fn op_f(a: Self, b: Tensor<TB, B, DB>) -> Result<Self::Output> {
            TensorOpAPI::op_f(&a, b)
        }
    }

    // Owned (op) owned with a single element type T: prefer consuming the
    // lhs, then the rhs, and only allocate when neither buffer is reusable.
    #[doc(hidden)]
    impl<T, DA, DB, DC, B> TensorOpAPI<Tensor<T, B, DB>> for Tensor<T, B, DA>
    where
        DA: DimAPI,
        DB: DimAPI,
        DC: DimAPI,
        B: DeviceAPI<T>,
        B: DeviceCreationAnyAPI<T>,
        DA: DimMaxAPI<DB, Max = DC> + DimIntoAPI<DC>,
        DB: DimMaxAPI<DA, Max = DC> + DimIntoAPI<DC>,
        DC: DimIntoAPI<DA> + DimIntoAPI<DB>,
        T: Op<T, Output = T>,
        B: DeviceOpAPI<T, T, T, DC>,
        B: DeviceLConsumeAPI<T, T, DA>,
        B: DeviceRConsumeAPI<T, T, DB>,
    {
        type Output = Tensor<T, B, DC>;
        fn op_f(a: Self, b: Tensor<T, B, DB>) -> Result<Self::Output> {
            rstsr_assert!(a.device().same_device(b.device()), DeviceMismatch)?;
            let la = a.layout();
            let lb = b.layout();
            let default_order = a.device().default_order();
            // First choice: consume the lhs buffer.
            let broadcast_result = broadcast_layout_to_first(la, lb, default_order);
            if !a.layout().is_broadcasted() && broadcast_result.is_ok() {
                let (la_b, _) = broadcast_result?;
                if la_b == *la {
                    return TensorOpAPI::op_f(a, &b);
                }
            }
            // Second choice: consume the rhs buffer.
            let broadcast_result = broadcast_layout_to_first(lb, la, default_order);
            if !b.layout().is_broadcasted() && broadcast_result.is_ok() {
                let (lb_b, _) = broadcast_result?;
                if lb_b == *lb {
                    return TensorOpAPI::op_f(&a, b);
                }
            }
            // Fallback: allocate a fresh output.
            return TensorOpAPI::op_f(&a, &b);
        }
    }

    // Cow lhs: unwrap into owned (consuming path) or view (by-ref path)
    // depending on whether the cow currently owns its data.
    #[doc(hidden)]
    #[duplicate_item(
        RType                                             TrB                       ;
        [R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>] [&TensorAny<R, T, B, DB> ];
        [                                              ] [TensorView<'_, T, B, DB>];
        [                                              ] [Tensor<T, B, DB>        ];
    )]
    impl<T, DA, DB, DC, B, RType> TensorOpAPI<TrB> for TensorCow<'_, T, B, DA>
    where
        DA: DimAPI,
        DB: DimAPI,
        DC: DimAPI,
        B: DeviceAPI<T>,
        B: DeviceCreationAnyAPI<T>,
        DA: DimMaxAPI<DB, Max = DC> + DimIntoAPI<DC>,
        DB: DimMaxAPI<DA, Max = DC> + DimIntoAPI<DC>,
        DC: DimIntoAPI<DA> + DimIntoAPI<DB>,
        T: Op<T, Output = T>,
        B: DeviceOpAPI<T, T, T, DC>,
        B: DeviceLConsumeAPI<T, T, DA>,
        B: DeviceRConsumeAPI<T, T, DB>,
        T: Clone,
        <B as DeviceRawAPI<T>>::Raw: Clone,
        B: OpAssignAPI<T, DA>,
    {
        type Output = Tensor<T, B, DC>;
        fn op_f(a: Self, b: TrB) -> Result<Self::Output> {
            match a.is_owned() {
                true => TensorOpAPI::op_f(a.into_owned(), b),
                false => TensorOpAPI::op_f(a.view(), b),
            }
        }
    }

    // Cow rhs: same unwrap strategy on the right-hand side.
    #[doc(hidden)]
    #[duplicate_item(
        RType                                             TrA                       ;
        [R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>] [&TensorAny<R, T, B, DA> ];
        [                                              ] [TensorView<'_, T, B, DA>];
        [                                              ] [Tensor<T, B, DA>        ];
    )]
    impl<T, DA, DB, DC, B, RType> TensorOpAPI<TensorCow<'_, T, B, DB>> for TrA
    where
        DA: DimAPI,
        DB: DimAPI,
        DC: DimAPI,
        B: DeviceAPI<T>,
        B: DeviceCreationAnyAPI<T>,
        DA: DimMaxAPI<DB, Max = DC> + DimIntoAPI<DC>,
        DB: DimMaxAPI<DA, Max = DC> + DimIntoAPI<DC>,
        DC: DimIntoAPI<DA> + DimIntoAPI<DB>,
        T: Op<T, Output = T>,
        B: DeviceOpAPI<T, T, T, DC>,
        B: DeviceLConsumeAPI<T, T, DA>,
        B: DeviceRConsumeAPI<T, T, DB>,
        T: Clone,
        <B as DeviceRawAPI<T>>::Raw: Clone,
        B: OpAssignAPI<T, DB>,
    {
        type Output = Tensor<T, B, DC>;
        fn op_f(a: Self, b: TensorCow<'_, T, B, DB>) -> Result<Self::Output> {
            match b.is_owned() {
                true => TensorOpAPI::op_f(a, b.into_owned()),
                false => TensorOpAPI::op_f(a, b.view()),
            }
        }
    }

    // Cow (op) cow: unwrap both sides and dispatch to one of the four
    // owned/view combinations above.
    impl<T, DA, DB, DC, B> TensorOpAPI<TensorCow<'_, T, B, DB>> for TensorCow<'_, T, B, DA>
    where
        DA: DimAPI,
        DB: DimAPI,
        DC: DimAPI,
        B: DeviceAPI<T>,
        B: DeviceCreationAnyAPI<T>,
        DA: DimMaxAPI<DB, Max = DC> + DimIntoAPI<DC>,
        DB: DimMaxAPI<DA, Max = DC> + DimIntoAPI<DC>,
        DC: DimIntoAPI<DA> + DimIntoAPI<DB>,
        T: Op<T, Output = T>,
        B: DeviceOpAPI<T, T, T, DC>,
        B: DeviceLConsumeAPI<T, T, DA>,
        B: DeviceRConsumeAPI<T, T, DB>,
        T: Clone,
        <B as DeviceRawAPI<T>>::Raw: Clone,
        B: OpAssignAPI<T, DA> + OpAssignAPI<T, DB>,
    {
        type Output = Tensor<T, B, DC>;
        fn op_f(a: Self, b: TensorCow<'_, T, B, DB>) -> Result<Self::Output> {
            match (a.is_owned(), b.is_owned()) {
                (true, true) => TensorOpAPI::op_f(a.into_owned(), b.into_owned()),
                (true, false) => TensorOpAPI::op_f(a.into_owned(), b.view()),
                (false, true) => TensorOpAPI::op_f(a.view(), b.into_owned()),
                (false, false) => TensorOpAPI::op_f(a.view(), b.view()),
            }
        }
    }
}
566
// `*_with_output_f` free functions: compute `a (op) b` directly into a
// caller-supplied mutable output `c` (no allocation). `c`'s layout must be
// unchanged by broadcasting against both inputs.
#[duplicate_item(
    op                   op_f                   DeviceOpAPI       Op      ;
    [add_with_output   ] [add_with_output_f   ] [DeviceAddAPI   ] [Add   ];
    [sub_with_output   ] [sub_with_output_f   ] [DeviceSubAPI   ] [Sub   ];
    [mul_with_output   ] [mul_with_output_f   ] [DeviceMulAPI   ] [Mul   ];
    [div_with_output   ] [div_with_output_f   ] [DeviceDivAPI   ] [Div   ];
    [rem_with_output   ] [rem_with_output_f   ] [DeviceRemAPI   ] [Rem   ];
    [bitor_with_output ] [bitor_with_output_f ] [DeviceBitOrAPI ] [BitOr ];
    [bitand_with_output] [bitand_with_output_f] [DeviceBitAndAPI] [BitAnd];
    [bitxor_with_output] [bitxor_with_output_f] [DeviceBitXorAPI] [BitXor];
    [shl_with_output   ] [shl_with_output_f   ] [DeviceShlAPI   ] [Shl   ];
    [shr_with_output   ] [shr_with_output_f   ] [DeviceShrAPI   ] [Shr   ];
)]
pub fn op_f<TrA, TrB, TrC, TA, TB, TC, DA, DB, DC, B>(a: TrA, b: TrB, mut c: TrC) -> Result<()>
where
    TrA: TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
    TrB: TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
    TrC: TensorViewMutAPI<Type = TC, Backend = B, Dim = DC>,
    DA: DimAPI,
    DB: DimAPI,
    DC: DimAPI,
    B: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<TC>,
    DC: DimMaxAPI<DA, Max = DC> + DimMaxAPI<DB, Max = DC>,
    TA: Op<TB, Output = TC>,
    B: DeviceOpAPI<TA, TB, TC, DC>,
{
    let a = a.view();
    let b = b.view();
    let mut c = c.view_mut();
    // All three tensors must be on the same device.
    rstsr_assert!(c.device().same_device(a.device()), DeviceMismatch)?;
    rstsr_assert!(c.device().same_device(b.device()), DeviceMismatch)?;
    let lc = c.layout();
    let la = a.layout();
    let lb = b.layout();
    let default_order = c.device().default_order();
    // Broadcast each input against the output; the output layout itself
    // must be unchanged by broadcasting (c fully determines the shape).
    let (lc_b, la_b) = broadcast_layout_to_first(lc, la, default_order)?;
    rstsr_assert_eq!(lc_b, *lc, InvalidLayout)?;
    let (lc_b, lb_b) = broadcast_layout_to_first(lc, lb, default_order)?;
    rstsr_assert_eq!(lc_b, *lc, InvalidLayout)?;
    let device = c.device().clone();
    // SAFETY(review): the device kernel takes the output buffer as
    // `Raw<MaybeUninit<TC>>`; this transmute presumes `Raw<TC>` and
    // `Raw<MaybeUninit<TC>>` share the same layout (true for e.g. `Vec<T>`
    // vs `Vec<MaybeUninit<T>>`) — confirm this holds for every backend raw
    // type. The buffer is already initialized, which is strictly stronger
    // than the kernel's requirement.
    let c_raw_mut = unsafe {
        transmute::<&mut <B as DeviceRawAPI<TC>>::Raw, &mut <B as DeviceRawAPI<MaybeUninit<TC>>>::Raw>(c.raw_mut())
    };
    device.op_mutc_refa_refb(c_raw_mut, &lc_b, a.raw(), &la_b, b.raw(), &lb_b)
}
626
627#[duplicate_item(
628 op op_f DeviceOpAPI Op ;
629 [add_with_output ] [add_with_output_f ] [DeviceAddAPI ] [Add ];
630 [sub_with_output ] [sub_with_output_f ] [DeviceSubAPI ] [Sub ];
631 [mul_with_output ] [mul_with_output_f ] [DeviceMulAPI ] [Mul ];
632 [div_with_output ] [div_with_output_f ] [DeviceDivAPI ] [Div ];
633 [rem_with_output ] [rem_with_output_f ] [DeviceRemAPI ] [Rem ];
634 [bitor_with_output ] [bitor_with_output_f ] [DeviceBitOrAPI ] [BitOr ];
635 [bitand_with_output] [bitand_with_output_f] [DeviceBitAndAPI] [BitAnd];
636 [bitxor_with_output] [bitxor_with_output_f] [DeviceBitXorAPI] [BitXor];
637 [shl_with_output ] [shl_with_output_f ] [DeviceShlAPI ] [Shl ];
638 [shr_with_output ] [shr_with_output_f ] [DeviceShrAPI ] [Shr ];
639 )]
640pub fn op<TrA, TrB, TrC, TA, TB, TC, DA, DB, DC, B>(a: TrA, b: TrB, c: TrC)
641where
642 TrA: TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
644 TrB: TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
645 TrC: TensorViewMutAPI<Type = TC, Backend = B, Dim = DC>,
646 DA: DimAPI,
648 DB: DimAPI,
649 DC: DimAPI,
650 B: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<TC>,
651 DC: DimMaxAPI<DA, Max = DC> + DimMaxAPI<DB, Max = DC>,
653 TA: Op<TB, Output = TC>,
655 B: DeviceOpAPI<TA, TB, TC, DC>,
656{
657 op_f(a, b, c).rstsr_unwrap()
658}
659
// Implements `scalar (op) tensor` for a concrete scalar type `$ty`. The
// scalar is converted to the tensor's element type via `T: From<$ty>`.
// Covers four rhs flavors: `&TensorAny`, `TensorView`, owned `Tensor`
// (consumed in place) and `TensorCow`.
macro_rules! impl_arithmetic_scalar_lhs {
    ($ty: ty, $op: ident, $op_f: ident, $Op: ident, $DeviceOpAPI: ident, $TensorOpAPI: ident, $DeviceRConsumeOpAPI: ident) => {
        // scalar (op) &TensorAny -> fresh output tensor
        #[doc(hidden)]
        impl<T, R, D, B> $TensorOpAPI<&TensorAny<R, T, B, D>> for $ty
        where
            T: From<$ty> + $Op<T, Output = T>,
            R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
            D: DimAPI,
            B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
            B: $DeviceOpAPI<T, T, T, D>,
        {
            type Output = Tensor<T, B, D>;
            fn $op_f(a: Self, b: &TensorAny<R, T, B, D>) -> Result<Self::Output> {
                // Promote the scalar, allocate output matching b's copy
                // layout, then run the num-lhs device kernel.
                let a = T::from(a);
                let device = b.device();
                let lb = b.layout();
                let lc = layout_for_array_copy(lb, TensorIterOrder::default())?;
                let mut storage_c = device.uninit_impl(lc.bounds_index()?.1)?;
                device.op_mutc_numa_refb(storage_c.raw_mut(), &lc, a, b.raw(), lb)?;
                let storage_c = unsafe { B::assume_init_impl(storage_c) }?;
                Tensor::new_f(storage_c, lc)
            }
        }

        // Operator-syntax wrapper for the impl above.
        #[doc(hidden)]
        impl<T, R, D, B> $Op<&TensorAny<R, T, B, D>> for $ty
        where
            T: From<$ty> + $Op<T, Output = T>,
            R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
            D: DimAPI,
            B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
            B: $DeviceOpAPI<T, T, T, D>,
        {
            type Output = Tensor<T, B, D>;
            fn $op(self, rhs: &TensorAny<R, T, B, D>) -> Self::Output {
                $TensorOpAPI::$op_f(self, rhs).rstsr_unwrap()
            }
        }

        // scalar (op) TensorView -> fresh output tensor
        #[doc(hidden)]
        impl<T, B, D> $TensorOpAPI<TensorView<'_, T, B, D>> for $ty
        where
            T: From<$ty> + $Op<T, Output = T>,
            D: DimAPI,
            B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
            B: $DeviceOpAPI<T, T, T, D>,
        {
            type Output = Tensor<T, B, D>;
            fn $op_f(a: Self, b: TensorView<'_, T, B, D>) -> Result<Self::Output> {
                let a = T::from(a);
                let device = b.device();
                let lb = b.layout();
                let lc = layout_for_array_copy(lb, TensorIterOrder::default())?;
                let mut storage_c = device.uninit_impl(lc.bounds_index()?.1)?;
                device.op_mutc_numa_refb(storage_c.raw_mut(), &lc, a, b.raw(), lb)?;
                let storage_c = unsafe { B::assume_init_impl(storage_c) }?;
                Tensor::new_f(storage_c, lc)
            }
        }

        // Operator-syntax wrapper for the view impl.
        #[doc(hidden)]
        impl<T, B, D> $Op<TensorView<'_, T, B, D>> for $ty
        where
            T: From<$ty> + $Op<T, Output = T>,
            D: DimAPI,
            B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
            B: $DeviceOpAPI<T, T, T, D>,
        {
            type Output = Tensor<T, B, D>;
            fn $op(self, rhs: TensorView<'_, T, B, D>) -> Self::Output {
                $TensorOpAPI::$op_f(self, rhs).rstsr_unwrap()
            }
        }

        // scalar (op) owned Tensor -> result written into b's buffer in place
        #[doc(hidden)]
        impl<T, B, D> $TensorOpAPI<Tensor<T, B, D>> for $ty
        where
            T: From<$ty> + $Op<T, Output = T>,
            D: DimAPI,
            B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
            B: $DeviceRConsumeOpAPI<T, T, D>,
        {
            type Output = Tensor<T, B, D>;
            fn $op_f(a: Self, mut b: Tensor<T, B, D>) -> Result<Self::Output> {
                let a = T::from(a);
                let device = b.device().clone();
                let lb = b.layout().clone();
                device.op_muta_numb(b.raw_mut(), &lb, a)?;
                return Ok(b);
            }
        }

        // Operator-syntax wrapper for the consuming impl.
        #[doc(hidden)]
        impl<T, B, D> $Op<Tensor<T, B, D>> for $ty
        where
            T: From<$ty> + $Op<T, Output = T>,
            D: DimAPI,
            B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
            B: $DeviceRConsumeOpAPI<T, T, D>,
        {
            type Output = Tensor<T, B, D>;
            fn $op(self, rhs: Tensor<T, B, D>) -> Self::Output {
                $TensorOpAPI::$op_f(self, rhs).rstsr_unwrap()
            }
        }

        // scalar (op) TensorCow -> dispatch to owned or view impl.
        #[doc(hidden)]
        impl<T, B, D> $TensorOpAPI<TensorCow<'_, T, B, D>> for $ty
        where
            T: From<$ty> + $Op<T, Output = T>,
            D: DimAPI,
            B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
            B: $DeviceRConsumeOpAPI<T, T, D> + $DeviceOpAPI<T, T, T, D>,
            T: Clone,
            <B as DeviceRawAPI<T>>::Raw: Clone,
            B: OpAssignAPI<T, D>,
        {
            type Output = Tensor<T, B, D>;
            fn $op_f(a: Self, b: TensorCow<'_, T, B, D>) -> Result<Self::Output> {
                match b.is_owned() {
                    true => $TensorOpAPI::$op_f(a, b.into_owned()),
                    false => $TensorOpAPI::$op_f(a, b.view()),
                }
            }
        }

        // Operator-syntax wrapper for the cow impl.
        #[doc(hidden)]
        impl<T, B, D> $Op<TensorCow<'_, T, B, D>> for $ty
        where
            T: From<$ty> + $Op<T, Output = T>,
            D: DimAPI,
            B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
            B: $DeviceRConsumeOpAPI<T, T, D> + $DeviceOpAPI<T, T, T, D>,
            T: Clone,
            <B as DeviceRawAPI<T>>::Raw: Clone,
            B: OpAssignAPI<T, D>,
        {
            type Output = Tensor<T, B, D>;
            fn $op(self, rhs: TensorCow<'_, T, B, D>) -> Self::Output {
                $TensorOpAPI::$op_f(self, rhs).rstsr_unwrap()
            }
        }
    };
}
810
// All ten operations for integer scalar types.
#[rustfmt::skip]
macro_rules! impl_arithmetic_scalar_lhs_all {
    ($ty: ty) => {
        impl_arithmetic_scalar_lhs!($ty, add   , add_f   , Add   , DeviceAddAPI   , TensorAddAPI   , DeviceRConsumeAddAPI   );
        impl_arithmetic_scalar_lhs!($ty, sub   , sub_f   , Sub   , DeviceSubAPI   , TensorSubAPI   , DeviceRConsumeSubAPI   );
        impl_arithmetic_scalar_lhs!($ty, mul   , mul_f   , Mul   , DeviceMulAPI   , TensorMulAPI   , DeviceRConsumeMulAPI   );
        impl_arithmetic_scalar_lhs!($ty, div   , div_f   , Div   , DeviceDivAPI   , TensorDivAPI   , DeviceRConsumeDivAPI   );
        impl_arithmetic_scalar_lhs!($ty, rem   , rem_f   , Rem   , DeviceRemAPI   , TensorRemAPI   , DeviceRConsumeRemAPI   );
        impl_arithmetic_scalar_lhs!($ty, bitor , bitor_f , BitOr , DeviceBitOrAPI , TensorBitOrAPI , DeviceRConsumeBitOrAPI );
        impl_arithmetic_scalar_lhs!($ty, bitand, bitand_f, BitAnd, DeviceBitAndAPI, TensorBitAndAPI, DeviceRConsumeBitAndAPI);
        impl_arithmetic_scalar_lhs!($ty, bitxor, bitxor_f, BitXor, DeviceBitXorAPI, TensorBitXorAPI, DeviceRConsumeBitXorAPI);
        impl_arithmetic_scalar_lhs!($ty, shl   , shl_f   , Shl   , DeviceShlAPI   , TensorShlAPI   , DeviceRConsumeShlAPI   );
        impl_arithmetic_scalar_lhs!($ty, shr   , shr_f   , Shr   , DeviceShrAPI   , TensorShrAPI   , DeviceRConsumeShrAPI   );
    };
}
826
// Only the bitwise operations for `bool` scalars.
#[rustfmt::skip]
macro_rules! impl_arithmetic_scalar_lhs_bool {
    ($ty: ty) => {
        impl_arithmetic_scalar_lhs!($ty, bitor , bitor_f , BitOr , DeviceBitOrAPI , TensorBitOrAPI , DeviceRConsumeBitOrAPI );
        impl_arithmetic_scalar_lhs!($ty, bitand, bitand_f, BitAnd, DeviceBitAndAPI, TensorBitAndAPI, DeviceRConsumeBitAndAPI);
        impl_arithmetic_scalar_lhs!($ty, bitxor, bitxor_f, BitXor, DeviceBitXorAPI, TensorBitXorAPI, DeviceRConsumeBitXorAPI);
    };
}
835
// Only add/sub/mul/div for floating-point and complex scalars (no remainder
// or bitwise/shift ops here).
#[rustfmt::skip]
macro_rules! impl_arithmetic_scalar_lhs_float {
    ($ty: ty) => {
        impl_arithmetic_scalar_lhs!($ty, add   , add_f   , Add   , DeviceAddAPI   , TensorAddAPI   , DeviceRConsumeAddAPI   );
        impl_arithmetic_scalar_lhs!($ty, sub   , sub_f   , Sub   , DeviceSubAPI   , TensorSubAPI   , DeviceRConsumeSubAPI   );
        impl_arithmetic_scalar_lhs!($ty, mul   , mul_f   , Mul   , DeviceMulAPI   , TensorMulAPI   , DeviceRConsumeMulAPI   );
        impl_arithmetic_scalar_lhs!($ty, div   , div_f   , Div   , DeviceDivAPI   , TensorDivAPI   , DeviceRConsumeDivAPI   );
    };
}
845
// Instantiates the scalar-lhs impls for every supported scalar type:
// all ops for integers, bitwise ops for bool, and arithmetic ops for
// (half/complex) floats.
mod impl_arithmetic_scalar_lhs {
    use super::*;
    use half::{bf16, f16};
    use num::complex::Complex;
    impl_arithmetic_scalar_lhs_all!(i8);
    impl_arithmetic_scalar_lhs_all!(u8);
    impl_arithmetic_scalar_lhs_all!(i16);
    impl_arithmetic_scalar_lhs_all!(u16);
    impl_arithmetic_scalar_lhs_all!(i32);
    impl_arithmetic_scalar_lhs_all!(u32);
    impl_arithmetic_scalar_lhs_all!(i64);
    impl_arithmetic_scalar_lhs_all!(u64);
    impl_arithmetic_scalar_lhs_all!(i128);
    impl_arithmetic_scalar_lhs_all!(u128);
    impl_arithmetic_scalar_lhs_all!(isize);
    impl_arithmetic_scalar_lhs_all!(usize);

    impl_arithmetic_scalar_lhs_bool!(bool);

    impl_arithmetic_scalar_lhs_float!(bf16);
    impl_arithmetic_scalar_lhs_float!(f16);
    impl_arithmetic_scalar_lhs_float!(f32);
    impl_arithmetic_scalar_lhs_float!(f64);
    impl_arithmetic_scalar_lhs_float!(Complex<bf16>);
    impl_arithmetic_scalar_lhs_float!(Complex<f16>);
    impl_arithmetic_scalar_lhs_float!(Complex<f32>);
    impl_arithmetic_scalar_lhs_float!(Complex<f64>);
}
874
// `tensor (op) scalar`: the scalar rhs is any `num::Num` type convertible
// into the tensor's element type. Generated generically (unlike the lhs
// side, which must enumerate concrete scalar types because of Rust's
// orphan/coherence rules for impls on foreign primitives).
#[duplicate_item(
    op_f       Op       DeviceOpAPI       TensorOpAPI       DeviceLConsumeOpAPI     ;
    [add_f   ] [Add   ] [DeviceAddAPI   ] [TensorAddAPI   ] [DeviceLConsumeAddAPI   ];
    [sub_f   ] [Sub   ] [DeviceSubAPI   ] [TensorSubAPI   ] [DeviceLConsumeSubAPI   ];
    [mul_f   ] [Mul   ] [DeviceMulAPI   ] [TensorMulAPI   ] [DeviceLConsumeMulAPI   ];
    [div_f   ] [Div   ] [DeviceDivAPI   ] [TensorDivAPI   ] [DeviceLConsumeDivAPI   ];
    [rem_f   ] [Rem   ] [DeviceRemAPI   ] [TensorRemAPI   ] [DeviceLConsumeRemAPI   ];
    [bitor_f ] [BitOr ] [DeviceBitOrAPI ] [TensorBitOrAPI ] [DeviceLConsumeBitOrAPI ];
    [bitand_f] [BitAnd] [DeviceBitAndAPI] [TensorBitAndAPI] [DeviceLConsumeBitAndAPI];
    [bitxor_f] [BitXor] [DeviceBitXorAPI] [TensorBitXorAPI] [DeviceLConsumeBitXorAPI];
    [shl_f   ] [Shl   ] [DeviceShlAPI   ] [TensorShlAPI   ] [DeviceLConsumeShlAPI   ];
    [shr_f   ] [Shr   ] [DeviceShrAPI   ] [TensorShrAPI   ] [DeviceLConsumeShrAPI   ];
)]
mod impl_arithmetic_scalar_rhs {
    use super::*;

    // &TensorAny (op) scalar -> fresh output tensor
    #[doc(hidden)]
    impl<T, TB, R, D, B> TensorOpAPI<TB> for &TensorAny<R, T, B, D>
    where
        T: From<TB> + Op<T, Output = T>,
        R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
        D: DimAPI,
        B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
        B: DeviceOpAPI<T, T, T, D>,
        TB: num::Num,
    {
        type Output = Tensor<T, B, D>;
        fn op_f(a: Self, b: TB) -> Result<Self::Output> {
            // Promote the scalar, allocate output matching a's copy layout,
            // then run the num-rhs device kernel.
            let b = T::from(b);
            let device = a.device();
            let la = a.layout();
            let lc = layout_for_array_copy(la, TensorIterOrder::default())?;
            let mut storage_c = device.uninit_impl(lc.bounds_index()?.1)?;
            device.op_mutc_refa_numb(storage_c.raw_mut(), &lc, a.raw(), la, b)?;
            let storage_c = unsafe { B::assume_init_impl(storage_c) }?;
            Tensor::new_f(storage_c, lc)
        }
    }

    // TensorView (op) scalar -> fresh output tensor
    #[doc(hidden)]
    impl<T, TB, D, B> TensorOpAPI<TB> for TensorView<'_, T, B, D>
    where
        T: From<TB> + Op<T, Output = T>,
        D: DimAPI,
        B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
        B: DeviceOpAPI<T, T, T, D>,
        TB: num::Num,
    {
        type Output = Tensor<T, B, D>;
        fn op_f(a: Self, b: TB) -> Result<Self::Output> {
            let b = T::from(b);
            let device = a.device();
            let la = a.layout();
            let lc = layout_for_array_copy(la, TensorIterOrder::default())?;
            let mut storage_c = device.uninit_impl(lc.bounds_index()?.1)?;
            device.op_mutc_refa_numb(storage_c.raw_mut(), &lc, a.raw(), la, b)?;
            let storage_c = unsafe { B::assume_init_impl(storage_c) }?;
            Tensor::new_f(storage_c, lc)
        }
    }

    // Owned Tensor (op) scalar -> result written into a's buffer in place
    #[doc(hidden)]
    impl<T, TB, D, B> TensorOpAPI<TB> for Tensor<T, B, D>
    where
        T: From<TB> + Op<T, Output = T>,
        D: DimAPI,
        B: DeviceAPI<T>,
        B: DeviceLConsumeOpAPI<T, T, D>,
        TB: num::Num,
    {
        type Output = Tensor<T, B, D>;
        fn op_f(mut a: Self, b: TB) -> Result<Self::Output> {
            let b = T::from(b);
            let device = a.device().clone();
            let la = a.layout().clone();
            device.op_muta_numb(a.raw_mut(), &la, b)?;
            return Ok(a);
        }
    }

    // TensorCow (op) scalar -> dispatch to owned or view impl.
    #[doc(hidden)]
    impl<T, TB, D, B> TensorOpAPI<TB> for TensorCow<'_, T, B, D>
    where
        T: From<TB> + Op<T, Output = T>,
        D: DimAPI,
        B: DeviceAPI<T>,
        B: DeviceLConsumeOpAPI<T, T, D> + DeviceOpAPI<T, T, T, D>,
        TB: num::Num,
        T: Clone,
        <B as DeviceRawAPI<T>>::Raw: Clone,
        B: DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
    {
        type Output = Tensor<T, B, D>;
        fn op_f(a: Self, b: TB) -> Result<Self::Output> {
            match a.is_owned() {
                true => TensorOpAPI::op_f(a.into_owned(), b),
                false => TensorOpAPI::op_f(a.view(), b),
            }
        }
    }
}
988
989#[cfg(test)]
994mod test {
995 use super::*;
996
    // Exercises `+` across operand flavors (ref, view, broadcast, flipped
    // strides) under the default row-major order.
    #[test]
    #[cfg(not(feature = "col_major"))]
    fn test_add_row_major() {
        // Contiguous 1-D, by reference.
        let a = linspace((1.0, 5.0, 5));
        let b = linspace((2.0, 10.0, 5));
        let c = &a + &b;
        let c_ref = vec![3., 6., 9., 12., 15.].into();
        assert!(allclose_f64(&c, &c_ref));

        // Same inputs through the free function `add`.
        let a = linspace((1.0, 5.0, 5));
        let b = linspace((2.0, 10.0, 5));
        let c = add(&a, &b);
        let c_ref = vec![3., 6., 9., 12., 15.].into();
        assert!(allclose_f64(&c, &c_ref));

        // Broadcast: (2, 3) + (3,).
        let a = linspace((1.0, 6.0, 6)).into_shape_assume_contig([2, 3]);
        let b = linspace((2.0, 6.0, 3));
        let c = &a + &b;
        let c_ref = vec![3., 6., 9., 6., 9., 12.].into();
        assert!(allclose_f64(&c, &c_ref));

        // Broadcast with differing ranks: (1, 2, 3) + (5, 1, 2, 1).
        let a = linspace((1.0, 6.0, 6));
        let a = a.into_shape_assume_contig([1, 2, 3]);
        let b = linspace((1.0, 10.0, 10));
        let b = b.into_shape_assume_contig([5, 1, 2, 1]);
        let c = &a + &b;
        let c_ref = vec![
            2., 3., 4., 6., 7., 8., 4., 5., 6., 8., 9., 10., 6., 7., 8., 10., 11., 12., 8., 9., 10., 12., 13., 14.,
            10., 11., 12., 14., 15., 16.,
        ];
        let c_ref = c_ref.into();
        assert!(allclose_f64(&c, &c_ref));

        // Non-contiguous rhs: transposed (3, 3).
        let a = linspace((1.0, 9.0, 9));
        let a = a.into_shape_assume_contig([3, 3]);
        let b = linspace((2.0, 18.0, 9));
        let b = b.into_shape_assume_contig([3, 3]).into_reverse_axes();
        let c = &a + &b;
        let c_ref = vec![3., 10., 17., 8., 15., 22., 13., 20., 27.].into();
        assert!(allclose_f64(&c, &c_ref));

        // Negative-stride lhs (flipped).
        let a = linspace((1.0, 5.0, 5));
        let b = linspace((2.0, 10.0, 5));
        let a = a.flip(0);
        let c = &a + &b;
        let c_ref = vec![7., 8., 9., 10., 11.].into();
        assert!(allclose_f64(&c, &c_ref));

        // Negative-stride rhs (flipped).
        let a = linspace((1.0, 5.0, 5));
        let b = linspace((2.0, 10.0, 5));
        let b = b.flip(0);
        let c = &a + &b;
        let c_ref = vec![11., 10., 9., 8., 7.].into();
        assert!(allclose_f64(&c, &c_ref));

        // View lhs + ref rhs.
        let a = linspace((1.0, 5.0, 5));
        let b = linspace((2.0, 10.0, 5));
        let c = a.view() + &b;
        let c_ref = vec![3., 6., 9., 12., 15.].into();
        assert!(allclose_f64(&c, &c_ref));

        // Ref lhs + view rhs.
        let a = linspace((1.0, 5.0, 5));
        let b = linspace((2.0, 10.0, 5));
        let c = &a + b.view();
        let c_ref = vec![3., 6., 9., 12., 15.].into();
        assert!(allclose_f64(&c, &c_ref));
    }
1074
1075 #[test]
1076 #[cfg(feature = "col_major")]
1077 fn test_add_col_major() {
1078 let a = linspace((1.0, 5.0, 5));
1080 let b = linspace((2.0, 10.0, 5));
1081 let c = &a + &b;
1082 let c_ref = vec![3., 6., 9., 12., 15.].into();
1083 assert!(allclose_f64(&c, &c_ref));
1084
1085 let a = linspace((1.0, 5.0, 5));
1086 let b = linspace((2.0, 10.0, 5));
1087 let c = add(&a, &b);
1088 let c_ref = vec![3., 6., 9., 12., 15.].into();
1089 assert!(allclose_f64(&c, &c_ref));
1090
1091 let a = linspace((1.0, 6.0, 6)).into_shape_assume_contig([3, 2]);
1094 let b = linspace((2.0, 6.0, 3));
1095 let c = &a + &b;
1096 let c_ref = vec![3., 6., 9., 6., 9., 12.];
1097 assert!(allclose_f64(&c.raw().into(), &c_ref.into()));
1098
1099 let a = linspace((1.0, 6.0, 6));
1102 let a = a.into_shape_assume_contig([3, 2, 1]);
1103 let b = linspace((1.0, 10.0, 10));
1104 let b = b.into_shape_assume_contig([1, 2, 1, 5]);
1105 let c = &a + &b;
1106 let c_ref = vec![
1107 2., 3., 4., 6., 7., 8., 4., 5., 6., 8., 9., 10., 6., 7., 8., 10., 11., 12., 8., 9., 10., 12., 13., 14.,
1108 10., 11., 12., 14., 15., 16.,
1109 ];
1110 assert!(allclose_f64(&c.raw().into(), &c_ref.into()));
1111
1112 let a = linspace((1.0, 9.0, 9));
1114 let a = a.into_shape_assume_contig([3, 3]);
1115 let b = linspace((2.0, 18.0, 9));
1116 let b = b.into_shape_assume_contig([3, 3]).into_reverse_axes();
1117 let c = &a + &b;
1118 let c_ref = vec![3., 10., 17., 8., 15., 22., 13., 20., 27.];
1119 assert!(allclose_f64(&c.raw().into(), &c_ref.into()));
1120
1121 let a = linspace((1.0, 5.0, 5));
1123 let b = linspace((2.0, 10.0, 5));
1124 let a = a.flip(0);
1125 let c = &a + &b;
1126 let c_ref = vec![7., 8., 9., 10., 11.];
1127 assert!(allclose_f64(&c.raw().into(), &c_ref.into()));
1128
1129 let a = linspace((1.0, 5.0, 5));
1130 let b = linspace((2.0, 10.0, 5));
1131 let b = b.flip(0);
1132 let c = &a + &b;
1133 let c_ref = vec![11., 10., 9., 8., 7.];
1134 assert!(allclose_f64(&c.raw().into(), &c_ref.into()));
1135
1136 let a = linspace((1.0, 5.0, 5));
1138 let b = linspace((2.0, 10.0, 5));
1139 let c = a.view() + &b;
1140 let c_ref = vec![3., 6., 9., 12., 15.];
1141 assert!(allclose_f64(&c.raw().into(), &c_ref.into()));
1142
1143 let a = linspace((1.0, 5.0, 5));
1144 let b = linspace((2.0, 10.0, 5));
1145 let c = &a + b.view();
1146 let c_ref = vec![3., 6., 9., 12., 15.];
1147 assert!(allclose_f64(&c.raw().into(), &c_ref.into()));
1148 }
1149
1150 #[test]
1151 fn test_sub() {
1152 let a = linspace((1.0, 5.0, 5));
1154 let b = linspace((2.0, 10.0, 5));
1155 let c = &a - &b;
1156 let c_ref = vec![-1., -2., -3., -4., -5.].into();
1157 assert!(allclose_f64(&c, &c_ref));
1158 }
1159
1160 #[test]
1161 fn test_mul() {
1162 let a = linspace((1.0, 5.0, 5));
1164 let b = linspace((2.0, 10.0, 5));
1165 let c = &a * &b;
1166 let c_ref = vec![2., 8., 18., 32., 50.].into();
1167 assert!(allclose_f64(&c, &c_ref));
1168 }
1169
1170 #[test]
1171 #[cfg(not(feature = "col_major"))]
1172 fn test_add_consume_row_major() {
1173 let a = linspace((1.0, 5.0, 5));
1175 let b = linspace((2.0, 10.0, 5));
1176 let a_ptr = a.raw().as_ptr();
1177 let c = a + &b;
1178 let c_ptr = c.raw().as_ptr();
1179 let c_ref = vec![3., 6., 9., 12., 15.].into();
1180 assert!(allclose_f64(&c, &c_ref));
1181 assert_eq!(a_ptr, c_ptr);
1182 let a = linspace((1.0, 10.0, 10)).into_shape_assume_contig([2, 5]);
1184 let b = linspace((2.0, 10.0, 5));
1185 let a_ptr = a.raw().as_ptr();
1186 let c = a + &b;
1187 let c_ptr = c.raw().as_ptr();
1188 let c_ref = vec![3., 6., 9., 12., 15., 8., 11., 14., 17., 20.].into();
1189 assert!(allclose_f64(&c, &c_ref));
1190 assert_eq!(a_ptr, c_ptr);
1191 let a = linspace((2.0, 10.0, 5));
1193 let b = linspace((1.0, 10.0, 10)).into_shape_assume_contig([2, 5]);
1194 let a_ptr = a.raw().as_ptr();
1195 let c = a + &b;
1196 let c_ptr = c.raw().as_ptr();
1197 let c_ref = vec![3., 6., 9., 12., 15., 8., 11., 14., 17., 20.].into();
1198 assert!(allclose_f64(&c, &c_ref));
1199 assert_ne!(a_ptr, c_ptr);
1200 let a = linspace((1.0, 5.0, 5));
1202 let b = linspace((2.0, 10.0, 5));
1203 let b_ptr = b.raw().as_ptr();
1204 let c = &a + b;
1205 let c_ptr = c.raw().as_ptr();
1206 let c_ref = vec![3., 6., 9., 12., 15.].into();
1207 assert!(allclose_f64(&c, &c_ref));
1208 assert_eq!(b_ptr, c_ptr);
1209 let a = linspace((1.0, 10.0, 10)).into_shape_assume_contig([2, 5]);
1211 let b = linspace((2.0, 10.0, 5));
1212 let b_ptr = b.raw().as_ptr();
1213 let c = &a + b;
1214 let c_ptr = c.raw().as_ptr();
1215 let c_ref = vec![3., 6., 9., 12., 15., 8., 11., 14., 17., 20.].into();
1216 assert!(allclose_f64(&c, &c_ref));
1217 assert_ne!(b_ptr, c_ptr);
1218 let a = linspace((1.0, 5.0, 5));
1220 let b = linspace((2.0, 10.0, 5));
1221 let a_ptr = a.raw().as_ptr();
1222 let c = a + b;
1223 let c_ptr = c.raw().as_ptr();
1224 let c_ref = vec![3., 6., 9., 12., 15.].into();
1225 assert!(allclose_f64(&c, &c_ref));
1226 assert_eq!(a_ptr, c_ptr);
1227 }
1228
1229 #[test]
1230 #[cfg(feature = "col_major")]
1231 fn test_add_consume_col_major() {
1232 let a = linspace((1.0, 5.0, 5));
1234 let b = linspace((2.0, 10.0, 5));
1235 let a_ptr = a.raw().as_ptr();
1236 let c = a + &b;
1237 let c_ptr = c.raw().as_ptr();
1238 let c_ref = vec![3., 6., 9., 12., 15.];
1239 assert!(allclose_f64(&c.raw().into(), &c_ref.into()));
1240 assert_eq!(a_ptr, c_ptr);
1241 let a = linspace((1.0, 10.0, 10)).into_shape_assume_contig([5, 2]);
1243 let b = linspace((2.0, 10.0, 5));
1244 let a_ptr = a.raw().as_ptr();
1245 let c = a + &b;
1246 let c_ptr = c.raw().as_ptr();
1247 let c_ref = vec![3., 6., 9., 12., 15., 8., 11., 14., 17., 20.];
1248 assert!(allclose_f64(&c.raw().into(), &c_ref.into()));
1249 assert_eq!(a_ptr, c_ptr);
1250 let a = linspace((2.0, 10.0, 5));
1252 let b = linspace((1.0, 10.0, 10)).into_shape_assume_contig([5, 2]);
1253 let a_ptr = a.raw().as_ptr();
1254 let c = a + &b;
1255 let c_ptr = c.raw().as_ptr();
1256 let c_ref = vec![3., 6., 9., 12., 15., 8., 11., 14., 17., 20.];
1257 assert!(allclose_f64(&c.raw().into(), &c_ref.into()));
1258 assert_ne!(a_ptr, c_ptr);
1259 let a = linspace((1.0, 5.0, 5));
1261 let b = linspace((2.0, 10.0, 5));
1262 let b_ptr = b.raw().as_ptr();
1263 let c = &a + b;
1264 let c_ptr = c.raw().as_ptr();
1265 let c_ref = vec![3., 6., 9., 12., 15.];
1266 assert!(allclose_f64(&c.raw().into(), &c_ref.into()));
1267 assert_eq!(b_ptr, c_ptr);
1268 let a = linspace((1.0, 10.0, 10)).into_shape_assume_contig([5, 2]);
1270 let b = linspace((2.0, 10.0, 5));
1271 let b_ptr = b.raw().as_ptr();
1272 let c = &a + b;
1273 let c_ptr = c.raw().as_ptr();
1274 let c_ref = vec![3., 6., 9., 12., 15., 8., 11., 14., 17., 20.];
1275 assert!(allclose_f64(&c.raw().into(), &c_ref.into()));
1276 assert_ne!(b_ptr, c_ptr);
1277 let a = linspace((1.0, 5.0, 5));
1279 let b = linspace((2.0, 10.0, 5));
1280 let a_ptr = a.raw().as_ptr();
1281 let c = a + b;
1282 let c_ptr = c.raw().as_ptr();
1283 let c_ref = vec![3., 6., 9., 12., 15.];
1284 assert!(allclose_f64(&c.raw().into(), &c_ref.into()));
1285 assert_eq!(a_ptr, c_ptr);
1286 }
1287
1288 #[test]
1289 fn test_sub_consume() {
1290 let a = linspace((1.0, 5.0, 5));
1292 let b = linspace((2.0, 10.0, 5));
1293 let b_ptr = b.raw().as_ptr();
1294 let c = &a - b;
1295 let c_ptr = c.raw().as_ptr();
1296 let c_ref = vec![-1., -2., -3., -4., -5.].into();
1297 assert!(allclose_f64(&c, &c_ref));
1298 assert_eq!(b_ptr, c_ptr);
1299 let a = linspace((1.0, 5.0, 5));
1301 let b = linspace((2.0, 10.0, 5));
1302 let a_ptr = a.raw().as_ptr();
1303 let c = a - b.view();
1304 let c_ptr = c.raw().as_ptr();
1305 let c_ref = vec![-1., -2., -3., -4., -5.].into();
1306 assert!(allclose_f64(&c, &c_ref));
1307 assert_eq!(a_ptr, c_ptr);
1308 let a = linspace((1.0, 5.0, 5));
1310 let b = linspace((2.0, 10.0, 5));
1311 let a_ptr = a.raw().as_ptr();
1312 let c = a - b;
1313 let c_ptr = c.raw().as_ptr();
1314 let c_ref = vec![-1., -2., -3., -4., -5.].into();
1315 assert!(allclose_f64(&c, &c_ref));
1316 assert_eq!(a_ptr, c_ptr);
1317 }
1318}
1319
#[cfg(test)]
mod test_with_output {
    use super::*;

    #[test]
    fn test_op_binary_with_output() {
        // Writes the broadcast sum into a pre-existing mutable view instead of
        // allocating a new output tensor. Only prints the result — no value
        // assertion; this is a smoke test for the `*_with_output` path.
        #[cfg(not(feature = "col_major"))]
        {
            let a = linspace((1.0, 10.0, 10)).into_shape_assume_contig([2, 5]);
            // rhs converted to an explicit row-major layout before the call.
            let b = linspace((2.0, 10.0, 5)).into_layout([5].c());
            let mut c = linspace((1.0, 10.0, 10)).into_shape_assume_contig([2, 5]);
            let c_view = c.view_mut();
            add_with_output(&a, b, c_view);
            println!("{c:?}");
        }
        #[cfg(feature = "col_major")]
        {
            // Col-major mirror: shapes transposed to [5, 2].
            let a = linspace((1.0, 10.0, 10)).into_shape_assume_contig([5, 2]);
            let b = linspace((2.0, 10.0, 5)).into_layout([5].c());
            let mut c = linspace((1.0, 10.0, 10)).into_shape_assume_contig([5, 2]);
            let c_view = c.view_mut();
            add_with_output(&a, b, c_view);
            println!("{c:?}");
        }
    }
}
1346
#[cfg(test)]
mod tests_with_scalar {
    use super::*;

    #[test]
    fn test_add() {
        // NOTE(review): despite its name, this test exercises scalar `-` and
        // `*` (not `+`) — consider renaming to e.g. `test_scalar_ops`.
        // scalar - &tensor: the integer scalar is lifted to the element type.
        let a = linspace((1.0, 5.0, 5));
        let b = 1;
        let c = b - &a;
        let c_ref = vec![0., -1., -2., -3., -4.].into();
        assert!(allclose_f64(&c, &c_ref));

        // &tensor - scalar.
        let a = linspace((1.0, 5.0, 5));
        let b = 1;
        let c = &a - b;
        let c_ref = vec![0., 1., 2., 3., 4.].into();
        assert!(allclose_f64(&c, &c_ref));

        // scalar * consumed tensor: the operation runs in place, so the
        // result must reuse `a`'s buffer (pointer equality below).
        let a = linspace((1.0, 5.0, 5));
        let a_ptr = a.raw().as_ptr();
        let b = 2;
        let c: Tensor<_> = -b * a;
        let c_ref = vec![-2., -4., -6., -8., -10.].into();
        assert!(allclose_f64(&c, &c_ref));
        let c_ptr = c.raw().as_ptr();
        assert_eq!(a_ptr, c_ptr);
    }

    #[test]
    fn test_scalar_consequent() {
        // Smoke test: scalar multiply followed by an indexed in-place add.
        let a = linspace((1.0, 5.0, 5));
        let mut c = linspace((1.0, 5.0, 5));
        let b = a * 2;
        // In-place `+=` on a single mutably-indexed element.
        *&mut c.i_mut(1) += b.i(1);
        println!("{c:?}");
    }

    #[test]
    fn test_cow() {
        // `reshape` of a contiguous tensor yields a borrowing cow;
        // `change_shape` after a swapaxes must copy, yielding an owned cow.
        let a = linspace((1.0, 24.0, 24)).into_shape((2, 3, 4));
        let a_cow_view = a.reshape((2, 3, 4));
        let a_cow_owned = a.view().into_swapaxes(-1, -2).change_shape((2, 3, 4));
        let ptr_a_cow_owned = a_cow_owned.raw().as_ptr();
        assert!(!a_cow_view.is_owned());
        assert!(a_cow_owned.is_owned());

        // cow(view) operand: output cannot reuse its (borrowed) buffer.
        let b = a.reshape((2, 3, 4)) + a_cow_view;
        let ptr_b = b.raw().as_ptr();
        println!("{b:?}");
        assert_ne!(ptr_a_cow_owned, ptr_b);

        // cow(owned) operand: its buffer is consumed and reused for the output.
        let b = a.reshape((2, 3, 4)) + a_cow_owned;
        let ptr_b = b.raw().as_ptr();
        println!("{b:?}");
        assert_eq!(ptr_a_cow_owned, ptr_b);

        // cow * scalar smoke test.
        let b = a.reshape((2, 3, 4)) * 2.0;
        println!("{b:?}");
    }
}
1413
1414