rstsr_core/tensor/manuplication/broadcast.rs
use crate::prelude_dev::*;

/* #region broadcast_arrays */

/// Broadcasts any number of arrays against each other.
///
/// <div class="warning">
///
/// **Row/Column Major Notice**
///
/// This function behaves differently depending on the default order ([`RowMajor`] or
/// [`ColMajor`]) of the device.
///
/// </div>
///
/// # Parameters
///
/// - `tensors`: [`Vec<TensorAny<R, T, B, IxD>>`](TensorAny)
///
///   - The tensors to be broadcasted.
///   - All tensors must be on the same device and share the same ownership type.
///   - This function takes ownership of the input tensors. If you want to obtain broadcasted
///     views, create a new vector of views first (see the sketch at the end of Examples).
///   - This function only accepts dynamic shape tensors ([`IxD`]).
///
/// # Returns
///
/// - [`Vec<TensorAny<R, T, B, IxD>>`](TensorAny)
///
///   - A vector of broadcasted tensors. Each tensor has the same shape after broadcasting.
///   - The ownership of the underlying data is moved from the input tensors to the output
///     tensors.
///   - The tensors are typically not contiguous (with zero strides at the broadcasted axes).
///     Writing values to broadcasted tensors is dangerous, but RSTSR will generally not panic on
///     this behavior. Perform [`to_contig`] afterwards if you require owned contiguous tensors.
///
/// # Examples
///
/// The following example demonstrates how to use `broadcast_arrays` to broadcast two tensors:
///
/// ```rust
/// use rstsr::prelude::*;
/// let mut device = DeviceCpu::default();
/// device.set_default_order(RowMajor);
///
/// let a = rt::asarray((vec![1, 2, 3], &device)).into_shape([3]);
/// let b = rt::asarray((vec![4, 5], &device)).into_shape([2, 1]);
///
/// let result = rt::broadcast_arrays(vec![a, b]);
/// let expected_a = rt::tensor_from_nested!(
///     [[1, 2, 3],
///      [1, 2, 3]],
///     &device);
/// let expected_b = rt::tensor_from_nested!(
///     [[4, 4, 4],
///      [5, 5, 5]],
///     &device);
/// assert!(rt::allclose!(&result[0], &expected_a));
/// assert!(rt::allclose!(&result[1], &expected_b));
/// ```
///
/// Please note that the above code only works in [RowMajor] order.
///
/// In [ColMajor] order, the broadcasting rules are applied along the other axis, so the shapes
/// above become incompatible and broadcasting fails. You need to make the following changes for
/// the [ColMajor] case to work:
///
/// ```rust
/// # use rstsr::prelude::*;
/// let mut device = DeviceCpu::default();
/// device.set_default_order(ColMajor);
/// // Note shape of `a` changed from [3] to [1, 3]
/// let a = rt::asarray((vec![1, 2, 3], &device)).into_shape([1, 3]);
/// let b = rt::asarray((vec![4, 5], &device)).into_shape([2, 1]);
/// #
/// # let result = rt::broadcast_arrays(vec![a, b]);
/// # let expected_a = rt::tensor_from_nested!(
/// #     [[1, 2, 3],
/// #      [1, 2, 3]],
/// #     &device);
/// # let expected_b = rt::tensor_from_nested!(
/// #     [[4, 4, 4],
/// #      [5, 5, 5]],
/// #     &device);
/// # assert!(rt::allclose!(&result[0], &expected_a));
/// # assert!(rt::allclose!(&result[1], &expected_b));
/// ```
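///
/// If you need to keep the original tensors rather than consuming them, broadcast views instead.
/// A minimal sketch (assuming row-major order; `view` yields a read-only view, as used elsewhere
/// in this module):
///
/// ```rust
/// # use rstsr::prelude::*;
/// # let mut device = DeviceCpu::default();
/// # device.set_default_order(RowMajor);
/// let a = rt::asarray((vec![1, 2, 3], &device)).into_shape([3]);
/// let b = rt::asarray((vec![4, 5], &device)).into_shape([2, 1]);
/// // the views are consumed, while `a` and `b` keep their ownership
/// let views = rt::broadcast_arrays(vec![a.view(), b.view()]);
/// assert_eq!(views[0].shape(), &[2, 3]);
/// assert_eq!(views[1].shape(), &[2, 3]);
/// ```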
///
/// # Panics
///
/// - Incompatible shapes to be broadcasted.
/// - Tensors are on different devices.
///
/// # See also
///
/// ## Similar function from other crates/libraries
///
/// - Python Array API standard: [`broadcast_arrays`](https://data-apis.org/array-api/2024.12/API_specification/generated/array_api.broadcast_arrays.html)
/// - NumPy: [`numpy.broadcast_arrays`](https://numpy.org/doc/stable/reference/generated/numpy.broadcast_arrays.html)
///
/// ## Related functions in RSTSR
///
/// - [`to_broadcast`]: Broadcasts a single array to a specified shape.
///
/// ## Variants of this function
///
/// - [`broadcast_arrays_f`]: Fallible version, actual implementation.
pub fn broadcast_arrays<R, T, B>(tensors: Vec<TensorAny<R, T, B, IxD>>) -> Vec<TensorAny<R, T, B, IxD>>
where
    R: DataAPI<Data = B::Raw>,
    B: DeviceAPI<T>,
{
    broadcast_arrays_f(tensors).rstsr_unwrap()
}

/// Broadcasts any number of arrays against each other.
///
/// # See also
///
/// Refer to [`broadcast_arrays`] for detailed documentation.
pub fn broadcast_arrays_f<R, T, B>(tensors: Vec<TensorAny<R, T, B, IxD>>) -> Result<Vec<TensorAny<R, T, B, IxD>>>
where
    R: DataAPI<Data = B::Raw>,
    B: DeviceAPI<T>,
{
    // fast return if there is only zero/one tensor
    if tensors.len() <= 1 {
        return Ok(tensors);
    }
    let device_b = tensors[0].device().clone();
    let default_order = device_b.default_order();
    let mut shape_b = tensors[0].shape().clone();
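    // check devices and fold the common broadcast shape over all input tensors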
    for tensor in tensors.iter().skip(1) {
        rstsr_assert!(device_b.same_device(tensor.device()), DeviceMismatch)?;
        let shape = tensor.shape();
        let (shape, _, _) = broadcast_shape(shape, &shape_b, default_order)?;
        shape_b = shape;
    }
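    // rebuild every tensor with the common shape; only the layouts change, the data moves as-is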
    let mut tensors_new = Vec::with_capacity(tensors.len());
    for tensor in tensors {
        let tensor = into_broadcast_f(tensor, shape_b.clone())?;
        tensors_new.push(tensor);
    }
    Ok(tensors_new)
}

/* #endregion */

/* #region broadcast_to */

/// Broadcasts an array to a specified shape.
///
/// # See also
///
/// Refer to [`to_broadcast`] and [`into_broadcast`] for detailed documentation.
pub fn into_broadcast_f<R, T, B, D, D2>(tensor: TensorAny<R, T, B, D>, shape: D2) -> Result<TensorAny<R, T, B, D2>>
where
    R: DataAPI<Data = B::Raw>,
    B: DeviceAPI<T>,
    D: DimAPI + DimMaxAPI<D2, Max = D2>,
    D2: DimAPI,
{
    let shape1 = tensor.shape();
    let shape2 = &shape;
    let default_order = tensor.device().default_order();
    let (shape, tp1, _) = broadcast_shape(shape1, shape2, default_order)?;
    let (storage, layout) = tensor.into_raw_parts();
    let layout = update_layout_by_shape(&layout, &shape, &tp1, default_order)?;
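    // SAFETY: the broadcast layout addresses the same storage as the original layout, so the
    // unchecked constructor is sound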
    unsafe { Ok(TensorBase::new_unchecked(storage, layout)) }
}

/// Broadcasts an array to a specified shape.
///
/// <div class="warning">
///
/// **Row/Column Major Notice**
///
/// This function behaves differently depending on the default order ([`RowMajor`] or
/// [`ColMajor`]) of the device.
///
/// </div>
///
/// # Parameters
///
/// - `tensor`: [`&TensorAny<R, T, B, D>`](TensorAny)
///
///   - The input tensor to be broadcasted.
///
/// - `shape`: impl [`DimAPI`]
///
///   - The shape of the desired output tensor after broadcasting.
///   - Please note that [`IxD`] (`Vec<usize>`) and [`Ix<N>`] (`[usize; N]`) behave differently
///     here: [`IxD`] gives a dynamic shape tensor, while [`Ix<N>`] gives a static shape tensor.
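///
/// For instance, the following sketch contrasts the two (assuming row-major order, and that the
/// source tensor carries a static 1-D shape so that both kinds of target shape apply):
///
/// ```rust
/// # use rstsr::prelude::*;
/// # let mut device = DeviceCpu::default();
/// # device.set_default_order(RowMajor);
/// let a = rt::tensor_from_nested!([1, 2, 3], &device);
/// let v_dyn = a.to_broadcast(vec![2, 3]); // dynamic target shape (IxD)
/// let v_fix = a.to_broadcast([2, 3]);     // static target shape (Ix<2>)
/// assert_eq!(v_dyn.shape(), &[2, 3]);
/// assert_eq!(v_fix.shape(), &[2, 3]);
/// ```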
///
/// # Returns
///
/// - [`TensorView<'_, T, B, D2>`]
///
///   - A readonly view on the original tensor with the given shape. It is typically not
///     contiguous (perform [`to_contig`] afterwards if you require contiguous owned tensors).
///   - Furthermore, more than one element of a broadcasted tensor may refer to a single memory
///     location (zero strides at the broadcasted axes). Writing values to broadcasted tensors is
///     dangerous, but RSTSR will generally not panic on this behavior.
///   - If you want to convert the tensor itself (taking ownership instead of returning a view),
///     use [`into_broadcast`] instead.
///
/// # Examples
///
/// The following example demonstrates how to use `to_broadcast` to broadcast a 1-D tensor
/// (3-element vector) to a 2-D tensor (2x3 matrix) by repeating the original data along a new
/// axis.
///
/// ```rust
/// use rstsr::prelude::*;
/// let mut device = DeviceCpu::default();
/// device.set_default_order(RowMajor);
///
/// let a = rt::tensor_from_nested!([1, 2, 3], &device);
///
/// // broadcast (3, ) -> (2, 3) in row-major:
/// let result = a.to_broadcast(vec![2, 3]);
/// let expected = rt::tensor_from_nested!(
///     [[1, 2, 3],
///      [1, 2, 3]],
///     &device);
/// assert!(rt::allclose!(&result, &expected));
/// ```
///
/// Please note that the above example only works in RowMajor order. In ColMajor order, the
/// broadcasting will be done along the other axis:
///
/// ```rust
/// use rstsr::prelude::*;
/// let mut device = DeviceCpu::default();
/// device.set_default_order(ColMajor);
///
/// let a = rt::tensor_from_nested!([1, 2, 3], &device);
/// // in col-major, broadcast (3, ) -> (2, 3) will fail:
/// let result = a.to_broadcast_f(vec![2, 3]);
/// assert!(result.is_err());
///
/// // broadcast (3, ) -> (3, 2) in col-major:
/// let result = a.to_broadcast(vec![3, 2]);
/// let expected = rt::tensor_from_nested!(
///     [[1, 1],
///      [2, 2],
///      [3, 3]],
///     &device);
/// assert!(rt::allclose!(&result, &expected));
/// ```
///
/// # Panics
///
/// - Incompatible shapes to be broadcasted.
///
/// # Elaborated examples
///
/// ## Broadcasting behavior (in row-major)
///
/// This example does not call `to_broadcast` directly, but demonstrates the broadcasting
/// behavior.
///
/// ```rust
/// use rstsr::prelude::*;
/// let mut device = DeviceCpu::default();
/// device.set_default_order(RowMajor);
///
/// // A      (4d tensor): 8 x 1 x 6 x 1
/// // B      (3d tensor):     7 x 1 x 5
/// // ----------------------------------
/// // Result (4d tensor): 8 x 7 x 6 x 5
/// let a = rt::arange((48, &device)).into_shape([8, 1, 6, 1]);
/// let b = rt::arange((35, &device)).into_shape([7, 1, 5]);
/// let result = &a + &b;
/// assert_eq!(result.shape(), &[8, 7, 6, 5]);
///
/// // A      (2d tensor): 5 x 4
/// // B      (1d tensor):     1
/// // --------------------------
/// // Result (2d tensor): 5 x 4
/// let a = rt::arange((20, &device)).into_shape([5, 4]);
/// let b = rt::arange((1, &device)).into_shape([1]);
/// let result = &a + &b;
/// assert_eq!(result.shape(), &[5, 4]);
///
/// // A      (2d tensor): 5 x 4
/// // B      (1d tensor):     4
/// // --------------------------
/// // Result (2d tensor): 5 x 4
/// let a = rt::arange((20, &device)).into_shape([5, 4]);
/// let b = rt::arange((4, &device)).into_shape([4]);
/// let result = &a + &b;
/// assert_eq!(result.shape(), &[5, 4]);
///
/// // A      (3d tensor): 15 x 3 x 5
/// // B      (3d tensor): 15 x 1 x 5
/// // -------------------------------
/// // Result (3d tensor): 15 x 3 x 5
/// let a = rt::arange((225, &device)).into_shape([15, 3, 5]);
/// let b = rt::arange((75, &device)).into_shape([15, 1, 5]);
/// let result = &a + &b;
/// assert_eq!(result.shape(), &[15, 3, 5]);
///
/// // A      (3d tensor): 15 x 3 x 5
/// // B      (2d tensor):      3 x 5
/// // -------------------------------
/// // Result (3d tensor): 15 x 3 x 5
/// let a = rt::arange((225, &device)).into_shape([15, 3, 5]);
/// let b = rt::arange((15, &device)).into_shape([3, 5]);
/// let result = &a + &b;
/// assert_eq!(result.shape(), &[15, 3, 5]);
///
/// // A      (3d tensor): 15 x 3 x 5
/// // B      (2d tensor):      3 x 1
/// // -------------------------------
/// // Result (3d tensor): 15 x 3 x 5
/// let a = rt::arange((225, &device)).into_shape([15, 3, 5]);
/// let b = rt::arange((3, &device)).into_shape([3, 1]);
/// let result = &a + &b;
/// assert_eq!(result.shape(), &[15, 3, 5]);
/// ```
///
/// ## Broadcasting behavior (in col-major)
///
/// This example does not call `to_broadcast` directly, but demonstrates the broadcasting
/// behavior.
///
/// ```rust
/// use rstsr::prelude::*;
/// let mut device = DeviceCpu::default();
/// device.set_default_order(ColMajor);
///
/// // A      (4d tensor): 1 x 6 x 1 x 8
/// // B      (3d tensor): 5 x 1 x 7
/// // ----------------------------------
/// // Result (4d tensor): 5 x 6 x 7 x 8
/// let a = rt::arange((48, &device)).into_shape([1, 6, 1, 8]);
/// let b = rt::arange((35, &device)).into_shape([5, 1, 7]);
/// let result = &a + &b;
/// assert_eq!(result.shape(), &[5, 6, 7, 8]);
///
/// // A      (2d tensor): 4 x 5
/// // B      (1d tensor): 1
/// // --------------------------
/// // Result (2d tensor): 4 x 5
/// let a = rt::arange((20, &device)).into_shape([4, 5]);
/// let b = rt::arange((1, &device)).into_shape([1]);
/// let result = &a + &b;
/// assert_eq!(result.shape(), &[4, 5]);
///
/// // A      (2d tensor): 4 x 5
/// // B      (1d tensor): 4
/// // --------------------------
/// // Result (2d tensor): 4 x 5
/// let a = rt::arange((20, &device)).into_shape([4, 5]);
/// let b = rt::arange((4, &device)).into_shape([4]);
/// let result = &a + &b;
/// assert_eq!(result.shape(), &[4, 5]);
///
/// // A      (3d tensor): 5 x 3 x 15
/// // B      (3d tensor): 5 x 1 x 15
/// // -------------------------------
/// // Result (3d tensor): 5 x 3 x 15
/// let a = rt::arange((225, &device)).into_shape([5, 3, 15]);
/// let b = rt::arange((75, &device)).into_shape([5, 1, 15]);
/// let result = &a + &b;
/// assert_eq!(result.shape(), &[5, 3, 15]);
///
/// // A      (3d tensor): 5 x 3 x 15
/// // B      (2d tensor): 5 x 3
/// // -------------------------------
/// // Result (3d tensor): 5 x 3 x 15
/// let a = rt::arange((225, &device)).into_shape([5, 3, 15]);
/// let b = rt::arange((15, &device)).into_shape([5, 3]);
/// let result = &a + &b;
/// assert_eq!(result.shape(), &[5, 3, 15]);
///
/// // A      (3d tensor): 5 x 3 x 15
/// // B      (2d tensor): 1 x 3
/// // -------------------------------
/// // Result (3d tensor): 5 x 3 x 15
/// let a = rt::arange((225, &device)).into_shape([5, 3, 15]);
/// let b = rt::arange((3, &device)).into_shape([1, 3]);
/// let result = &a + &b;
/// assert_eq!(result.shape(), &[5, 3, 15]);
/// ```
///
/// # See also
///
/// ## Detailed broadcasting rules
///
/// - Python Array API standard: [Broadcasting rules](https://data-apis.org/array-api/latest/API_specification/broadcasting.html)
/// - NumPy: [Broadcasting](https://numpy.org/doc/stable/user/basics.broadcasting.html)
///
/// ## Similar function from other crates/libraries
///
/// - Python Array API standard: [`broadcast_to`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.broadcast_to.html)
/// - NumPy: [`numpy.broadcast_to`](https://numpy.org/doc/stable/reference/generated/numpy.broadcast_to.html)
/// - ndarray: [`ndarray::broadcast`](https://docs.rs/ndarray/latest/ndarray/struct.ArrayBase.html#method.broadcast)
///
/// ## Related functions in RSTSR
///
/// - [`broadcast_arrays`]: Broadcasts any number of arrays against each other.
///
/// ## Variants of this function
///
/// - [`to_broadcast`]: Standard version.
/// - [`to_broadcast_f`]: Fallible version.
/// - [`into_broadcast`]: Consuming version that takes ownership of the input tensor.
/// - [`into_broadcast_f`]: Consuming and fallible version, actual implementation.
/// - [`broadcast_to`]: Alias for `to_broadcast` (the name used by the Python Array API standard).
/// - Associated methods on [`TensorAny`]:
///
///   - [`TensorAny::to_broadcast`]
///   - [`TensorAny::to_broadcast_f`]
///   - [`TensorAny::into_broadcast`]
///   - [`TensorAny::into_broadcast_f`]
///   - [`TensorAny::broadcast_to`]
pub fn to_broadcast<R, T, B, D, D2>(tensor: &TensorAny<R, T, B, D>, shape: D2) -> TensorView<'_, T, B, D2>
where
    D: DimAPI + DimMaxAPI<D2, Max = D2>,
    D2: DimAPI,
    R: DataAPI<Data = B::Raw>,
    B: DeviceAPI<T>,
{
    into_broadcast_f(tensor.view(), shape).rstsr_unwrap()
}

/// Broadcasts an array to a specified shape.
///
/// # See also
///
/// Refer to [`to_broadcast`] or [`into_broadcast`] for detailed documentation.
pub fn broadcast_to<R, T, B, D, D2>(tensor: &TensorAny<R, T, B, D>, shape: D2) -> TensorView<'_, T, B, D2>
where
    D: DimAPI + DimMaxAPI<D2, Max = D2>,
    D2: DimAPI,
    R: DataAPI<Data = B::Raw>,
    B: DeviceAPI<T>,
{
    into_broadcast_f(tensor.view(), shape).rstsr_unwrap()
}

/// Broadcasts an array to a specified shape.
///
/// # See also
///
/// Refer to [`to_broadcast`] or [`into_broadcast`] for detailed documentation.
pub fn to_broadcast_f<R, T, B, D, D2>(tensor: &TensorAny<R, T, B, D>, shape: D2) -> Result<TensorView<'_, T, B, D2>>
where
    D: DimAPI + DimMaxAPI<D2, Max = D2>,
    D2: DimAPI,
    R: DataAPI<Data = B::Raw>,
    B: DeviceAPI<T>,
{
    into_broadcast_f(tensor.view(), shape)
}

/// Broadcasts an array to a specified shape.
///
/// <div class="warning">
///
/// **Row/Column Major Notice**
///
/// This function behaves differently depending on the default order ([`RowMajor`] or
/// [`ColMajor`]) of the device.
///
/// </div>
///
/// # Parameters
///
/// - `tensor`: [`TensorAny<R, T, B, D>`]
///
///   - The input tensor to be broadcasted.
///   - Ownership of the input tensor is taken.
///
/// - `shape`: impl [`DimAPI`]
///
///   - The shape of the desired output tensor after broadcasting.
///   - Please note that [`IxD`] (`Vec<usize>`) and [`Ix<N>`] (`[usize; N]`) behave differently
///     here: [`IxD`] gives a dynamic shape tensor, while [`Ix<N>`] gives a static shape tensor.
///
/// # Returns
///
/// - [`TensorAny<R, T, B, D2>`]
///
///   - The tensor with the given shape. It is typically not contiguous (perform [`to_contig`]
///     afterwards if you require a contiguous owned tensor).
///   - Furthermore, more than one element of a broadcasted tensor may refer to a single memory
///     location (zero strides at the broadcasted axes).
///   - Ownership of the returned tensor is transferred from the input tensor. Only the layout is
///     modified; the underlying data remains unchanged.
///
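/// # Examples
///
/// A minimal sketch of the consuming broadcast (assuming row-major order); `a` is moved into `b`
/// and only the layout changes:
///
/// ```rust
/// use rstsr::prelude::*;
/// let mut device = DeviceCpu::default();
/// device.set_default_order(RowMajor);
///
/// let a = rt::tensor_from_nested!([1, 2, 3], &device);
/// let b = a.into_broadcast(vec![2, 3]); // `a` is consumed here
/// assert_eq!(b.shape(), &[2, 3]);
/// ```
///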
/// # See also
///
/// Refer to [`to_broadcast`] for more detailed documentation.
pub fn into_broadcast<R, T, B, D, D2>(tensor: TensorAny<R, T, B, D>, shape: D2) -> TensorAny<R, T, B, D2>
where
    R: DataAPI<Data = B::Raw>,
    B: DeviceAPI<T>,
    D: DimAPI + DimMaxAPI<D2, Max = D2>,
    D2: DimAPI,
{
    into_broadcast_f(tensor, shape).rstsr_unwrap()
}

impl<R, T, B, D> TensorAny<R, T, B, D>
where
    R: DataAPI<Data = B::Raw>,
    B: DeviceAPI<T>,
    D: DimAPI,
{
    /// Broadcasts an array to a specified shape.
    ///
    /// # See also
    ///
    /// Refer to [`to_broadcast`] or [`into_broadcast`] for detailed documentation.
    pub fn to_broadcast<D2>(&self, shape: D2) -> TensorView<'_, T, B, D2>
    where
        D2: DimAPI,
        D: DimMaxAPI<D2, Max = D2>,
    {
        to_broadcast(self, shape)
    }

    /// Broadcasts an array to a specified shape.
    ///
    /// # See also
    ///
    /// Refer to [`to_broadcast`] or [`into_broadcast`] for detailed documentation.
    pub fn broadcast_to<D2>(&self, shape: D2) -> TensorView<'_, T, B, D2>
    where
        D2: DimAPI,
        D: DimMaxAPI<D2, Max = D2>,
    {
        broadcast_to(self, shape)
    }

    /// Broadcasts an array to a specified shape.
    ///
    /// # See also
    ///
    /// Refer to [`to_broadcast`] or [`into_broadcast`] for detailed documentation.
    pub fn to_broadcast_f<D2>(&self, shape: D2) -> Result<TensorView<'_, T, B, D2>>
    where
        D2: DimAPI,
        D: DimMaxAPI<D2, Max = D2>,
    {
        to_broadcast_f(self, shape)
    }

    /// Broadcasts an array to a specified shape.
    ///
    /// # See also
    ///
    /// Refer to [`to_broadcast`] or [`into_broadcast`] for detailed documentation.
    pub fn into_broadcast<D2>(self, shape: D2) -> TensorAny<R, T, B, D2>
    where
        D2: DimAPI,
        D: DimMaxAPI<D2, Max = D2>,
    {
        into_broadcast(self, shape)
    }

    /// Broadcasts an array to a specified shape.
    ///
    /// # See also
    ///
    /// Refer to [`to_broadcast`] or [`into_broadcast`] for detailed documentation.
    pub fn into_broadcast_f<D2>(self, shape: D2) -> Result<TensorAny<R, T, B, D2>>
    where
        D2: DimAPI,
        D: DimMaxAPI<D2, Max = D2>,
    {
        into_broadcast_f(self, shape)
    }
}

/* #endregion */

#[cfg(test)]
mod tests {

    #[test]
    #[rustfmt::skip]
    fn doc_broadcast_to() {
        use rstsr::prelude::*;
        let mut device = DeviceCpu::default();
        device.set_default_order(RowMajor);

        let a = rt::tensor_from_nested!([1, 2, 3], &device);

        // broadcast (3, ) -> (2, 3) in row-major:
        let result = a.to_broadcast(vec![2, 3]);
        let expected = rt::tensor_from_nested!(
            [[1, 2, 3],
             [1, 2, 3]],
            &device);
        assert!(rt::allclose!(&result, &expected));
    }

    #[test]
    #[rustfmt::skip]
    fn doc_broadcast_to_col_major() {
        use rstsr::prelude::*;
        let mut device = DeviceCpu::default();
        device.set_default_order(ColMajor);

        let a = rt::tensor_from_nested!([1, 2, 3], &device);
        // in col-major, broadcast (3, ) -> (2, 3) will fail:
        let result = a.to_broadcast_f(vec![2, 3]);
        assert!(result.is_err());

        // broadcast (3, ) -> (3, 2) in col-major:
        let result = a.to_broadcast(vec![3, 2]);
        let expected = rt::tensor_from_nested!(
            [[1, 1],
             [2, 2],
             [3, 3]],
            &device);
        assert!(rt::allclose!(&result, &expected));
    }

    #[test]
    fn doc_broadcast_to_elaborated_row_major() {
        use rstsr::prelude::*;
        let mut device = DeviceCpu::default();
        device.set_default_order(RowMajor);

        // A      (4d tensor): 8 x 1 x 6 x 1
        // B      (3d tensor):     7 x 1 x 5
        // ----------------------------------
        // Result (4d tensor): 8 x 7 x 6 x 5
        let a = rt::arange((48, &device)).into_shape([8, 1, 6, 1]);
        let b = rt::arange((35, &device)).into_shape([7, 1, 5]);
        let result = &a + &b;
        assert_eq!(result.shape(), &[8, 7, 6, 5]);

        // A      (2d tensor): 5 x 4
        // B      (1d tensor):     1
        // --------------------------
        // Result (2d tensor): 5 x 4
        let a = rt::arange((20, &device)).into_shape([5, 4]);
        let b = rt::arange((1, &device)).into_shape([1]);
        let result = &a + &b;
        assert_eq!(result.shape(), &[5, 4]);

        // A      (2d tensor): 5 x 4
        // B      (1d tensor):     4
        // --------------------------
        // Result (2d tensor): 5 x 4
        let a = rt::arange((20, &device)).into_shape([5, 4]);
        let b = rt::arange((4, &device)).into_shape([4]);
        let result = &a + &b;
        assert_eq!(result.shape(), &[5, 4]);

        // A      (3d tensor): 15 x 3 x 5
        // B      (3d tensor): 15 x 1 x 5
        // -------------------------------
        // Result (3d tensor): 15 x 3 x 5
        let a = rt::arange((225, &device)).into_shape([15, 3, 5]);
        let b = rt::arange((75, &device)).into_shape([15, 1, 5]);
        let result = &a + &b;
        assert_eq!(result.shape(), &[15, 3, 5]);

        // A      (3d tensor): 15 x 3 x 5
        // B      (2d tensor):      3 x 5
        // -------------------------------
        // Result (3d tensor): 15 x 3 x 5
        let a = rt::arange((225, &device)).into_shape([15, 3, 5]);
        let b = rt::arange((15, &device)).into_shape([3, 5]);
        let result = &a + &b;
        assert_eq!(result.shape(), &[15, 3, 5]);

        // A      (3d tensor): 15 x 3 x 5
        // B      (2d tensor):      3 x 1
        // -------------------------------
        // Result (3d tensor): 15 x 3 x 5
        let a = rt::arange((225, &device)).into_shape([15, 3, 5]);
        let b = rt::arange((3, &device)).into_shape([3, 1]);
        let result = &a + &b;
        assert_eq!(result.shape(), &[15, 3, 5]);
    }

    #[test]
    fn doc_broadcast_to_elaborated_col_major() {
        use rstsr::prelude::*;
        let mut device = DeviceCpu::default();
        device.set_default_order(ColMajor);

        // A      (4d tensor): 1 x 6 x 1 x 8
        // B      (3d tensor): 5 x 1 x 7
        // ----------------------------------
        // Result (4d tensor): 5 x 6 x 7 x 8
        let a = rt::arange((48, &device)).into_shape([1, 6, 1, 8]);
        let b = rt::arange((35, &device)).into_shape([5, 1, 7]);
        let result = &a + &b;
        assert_eq!(result.shape(), &[5, 6, 7, 8]);

        // A      (2d tensor): 4 x 5
        // B      (1d tensor): 1
        // --------------------------
        // Result (2d tensor): 4 x 5
        let a = rt::arange((20, &device)).into_shape([4, 5]);
        let b = rt::arange((1, &device)).into_shape([1]);
        let result = &a + &b;
        assert_eq!(result.shape(), &[4, 5]);

        // A      (2d tensor): 4 x 5
        // B      (1d tensor): 4
        // --------------------------
        // Result (2d tensor): 4 x 5
        let a = rt::arange((20, &device)).into_shape([4, 5]);
        let b = rt::arange((4, &device)).into_shape([4]);
        let result = &a + &b;
        assert_eq!(result.shape(), &[4, 5]);

        // A      (3d tensor): 5 x 3 x 15
        // B      (3d tensor): 5 x 1 x 15
        // -------------------------------
        // Result (3d tensor): 5 x 3 x 15
        let a = rt::arange((225, &device)).into_shape([5, 3, 15]);
        let b = rt::arange((75, &device)).into_shape([5, 1, 15]);
        let result = &a + &b;
        assert_eq!(result.shape(), &[5, 3, 15]);

        // A      (3d tensor): 5 x 3 x 15
        // B      (2d tensor): 5 x 3
        // -------------------------------
        // Result (3d tensor): 5 x 3 x 15
        let a = rt::arange((225, &device)).into_shape([5, 3, 15]);
        let b = rt::arange((15, &device)).into_shape([5, 3]);
        let result = &a + &b;
        assert_eq!(result.shape(), &[5, 3, 15]);

        // A      (3d tensor): 5 x 3 x 15
        // B      (2d tensor): 1 x 3
        // -------------------------------
        // Result (3d tensor): 5 x 3 x 15
        let a = rt::arange((225, &device)).into_shape([5, 3, 15]);
        let b = rt::arange((3, &device)).into_shape([1, 3]);
        let result = &a + &b;
        assert_eq!(result.shape(), &[5, 3, 15]);
    }

    #[test]
    #[rustfmt::skip]
    fn doc_broadcast_arrays_row_major() {
        use rstsr::prelude::*;
        let mut device = DeviceCpu::default();
        device.set_default_order(RowMajor);

        let a = rt::asarray((vec![1, 2, 3], &device)).into_shape([3]);
        let b = rt::asarray((vec![4, 5], &device)).into_shape([2, 1]);

        let result = rt::broadcast_arrays(vec![a, b]);
        let expected_a = rt::tensor_from_nested!(
            [[1, 2, 3],
             [1, 2, 3]],
            &device);
        let expected_b = rt::tensor_from_nested!(
            [[4, 4, 4],
             [5, 5, 5]],
            &device);
        assert!(rt::allclose!(&result[0], &expected_a));
        assert!(rt::allclose!(&result[1], &expected_b));
    }

    #[test]
    #[rustfmt::skip]
    fn doc_broadcast_arrays_col_major() {
        use rstsr::prelude::*;
        let mut device = DeviceCpu::default();
        device.set_default_order(ColMajor);

        let a = rt::asarray((vec![1, 2, 3], &device)).into_shape([1, 3]);
        let b = rt::asarray((vec![4, 5], &device)).into_shape([2, 1]);

        let result = rt::broadcast_arrays(vec![a, b]);
        let expected_a = rt::tensor_from_nested!(
            [[1, 2, 3],
             [1, 2, 3]],
            &device);
        let expected_b = rt::tensor_from_nested!(
            [[4, 4, 4],
             [5, 5, 5]],
            &device);
        assert!(rt::allclose!(&result[0], &expected_a));
        assert!(rt::allclose!(&result[1], &expected_b));
    }
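
    // A sketch test mirroring the `into_broadcast` doc example above; like the other tests in
    // this module, it assumes row-major order.
    #[test]
    fn doc_into_broadcast() {
        use rstsr::prelude::*;
        let mut device = DeviceCpu::default();
        device.set_default_order(RowMajor);

        let a = rt::tensor_from_nested!([1, 2, 3], &device);
        // `a` is consumed; only the layout changes, the underlying data is reused
        let b = a.into_broadcast(vec![2, 3]);
        assert_eq!(b.shape(), &[2, 3]);
    }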
}