rstsr_core/tensor/manuplication/
broadcast.rs

1use crate::prelude_dev::*;
2
3/* #region broadcast_arrays */
4
5/// Broadcasts any number of arrays against each other.
6///
7/// <div class="warning">
8///
9/// **Row/Column Major Notice**
10///
11/// This function behaves differently on default orders ([`RowMajor`] and [`ColMajor`]) of device.
12///
13/// </div>
14///
15/// # Parameters
16///
17/// - `tensors`: [`Vec<TensorAny<R, T, B, IxD>>`](TensorAny)
18///
19///   - The tensors to be broadcasted.
///   - All tensors must be on the same device, and share the same ownership type.
21///   - This function takes ownership of the input tensors. If you want to obtain broadcasted views,
22///     you need to create a new vector of views first.
23///   - This function only accepts dynamic shape tensors ([`IxD`]).
24///
25/// # Returns
26///
27/// - [`Vec<TensorAny<R, T, B, IxD>>`](TensorAny)
28///
29///   - A vector of broadcasted tensors. Each tensor has the same shape after broadcasting.
30///   - The ownership of the underlying data is moved from the input tensors to the output tensors.
31///   - The tensors are typically not contiguous (with zero strides at the broadcasted axes).
32///     Writing values to broadcasted tensors is dangerous, but RSTSR will generally not panic on
///     this behavior. Perform [`to_contig`] afterwards if you require owned contiguous tensors.
34///
35/// # Examples
36///
37/// The following example demonstrates how to use `broadcast_arrays` to broadcast two tensors:
38///
39/// ```rust
40/// use rstsr::prelude::*;
41/// let mut device = DeviceCpu::default();
42/// device.set_default_order(RowMajor);
43///
44/// let a = rt::asarray((vec![1, 2, 3], &device)).into_shape([3]);
45/// let b = rt::asarray((vec![4, 5], &device)).into_shape([2, 1]);
46///
47/// let result = rt::broadcast_arrays(vec![a, b]);
48/// let expected_a = rt::tensor_from_nested!(
49///     [[1, 2, 3],
50///      [1, 2, 3]],
51///     &device);
52/// let expected_b = rt::tensor_from_nested!(
53///     [[4, 4, 4],
54///      [5, 5, 5]],
55///     &device);
56/// assert!(rt::allclose!(&result[0], &expected_a));
57/// assert!(rt::allclose!(&result[1], &expected_b));
58/// ```
59///
60/// Please note that the above code only works in [RowMajor].
61///
62/// For [ColMajor] order, the broadcasting will fail, because the broadcasting rules are applied
63/// differently, shapes are incompatible. You need to make the following changes to let [ColMajor]
64/// case work:
65///
66/// ```rust
67/// # use rstsr::prelude::*;
68/// let mut device = DeviceCpu::default();
69/// device.set_default_order(ColMajor);
70/// // Note shape of `a` changed from [3] to [1, 3]
71/// let a = rt::asarray((vec![1, 2, 3], &device)).into_shape([1, 3]);
72/// let b = rt::asarray((vec![4, 5], &device)).into_shape([2, 1]);
73/// #
74/// # let result = rt::broadcast_arrays(vec![a, b]);
75/// # let expected_a = rt::tensor_from_nested!(
76/// #     [[1, 2, 3],
77/// #      [1, 2, 3]],
78/// #     &device);
79/// # let expected_b = rt::tensor_from_nested!(
80/// #     [[4, 4, 4],
81/// #      [5, 5, 5]],
82/// #     &device);
83/// # assert!(rt::allclose!(&result[0], &expected_a));
84/// # assert!(rt::allclose!(&result[1], &expected_b));
85/// ```
86///
87/// # Panics
88///
89/// - Incompatible shapes to be broadcasted.
90/// - Tensors are on different devices.
91///
92/// # See also
93///
94/// ## Similar function from other crates/libraries
95///
96/// - Python Array API standard: [`broadcast_arrays`](https://data-apis.org/array-api/2024.12/API_specification/generated/array_api.broadcast_arrays.html)
97/// - NumPy: [`numpy.broadcast_arrays`](https://numpy.org/doc/stable/reference/generated/numpy.broadcast_arrays.html)
98///
99/// ## Related functions in RSTSR
100///
101/// - [`to_broadcast`]: Broadcasts a single array to a specified shape.
102///
103/// ## Variants of this function
104///
105/// - [`broadcast_arrays_f`]: Fallible version, actual implementation.
106pub fn broadcast_arrays<R, T, B>(tensors: Vec<TensorAny<R, T, B, IxD>>) -> Vec<TensorAny<R, T, B, IxD>>
107where
108    R: DataAPI<Data = B::Raw>,
109    B: DeviceAPI<T>,
110{
111    broadcast_arrays_f(tensors).rstsr_unwrap()
112}
113
114/// Broadcasts any number of arrays against each other.
115///
116/// # See also
117///
118/// Refer to [`broadcast_arrays`] for detailed documentation.
119pub fn broadcast_arrays_f<R, T, B>(tensors: Vec<TensorAny<R, T, B, IxD>>) -> Result<Vec<TensorAny<R, T, B, IxD>>>
120where
121    R: DataAPI<Data = B::Raw>,
122    B: DeviceAPI<T>,
123{
124    // fast return if there is only zero/one tensor
125    if tensors.len() <= 1 {
126        return Ok(tensors);
127    }
128    let device_b = tensors[0].device().clone();
129    let default_order = device_b.default_order();
130    let mut shape_b = tensors[0].shape().clone();
131    for tensor in tensors.iter().skip(1) {
132        rstsr_assert!(device_b.same_device(tensor.device()), DeviceMismatch)?;
133        let shape = tensor.shape();
134        let (shape, _, _) = broadcast_shape(shape, &shape_b, default_order)?;
135        shape_b = shape;
136    }
137    let mut tensors_new = Vec::with_capacity(tensors.len());
138    for tensor in tensors {
139        let tensor = into_broadcast_f(tensor, shape_b.clone())?;
140        tensors_new.push(tensor);
141    }
142    return Ok(tensors_new);
143}
144
145/* #endregion */
146
147/* #region broadcast_to */
148
149/// Broadcasts an array to a specified shape.
150///
151/// # See also
152///
153/// Refer to [`to_broadcast`] and [`into_broadcast`] for detailed documentation.
154pub fn into_broadcast_f<R, T, B, D, D2>(tensor: TensorAny<R, T, B, D>, shape: D2) -> Result<TensorAny<R, T, B, D2>>
155where
156    R: DataAPI<Data = B::Raw>,
157    B: DeviceAPI<T>,
158    D: DimAPI + DimMaxAPI<D2, Max = D2>,
159    D2: DimAPI,
160{
161    let shape1 = tensor.shape();
162    let shape2 = &shape;
163    let default_order = tensor.device().default_order();
164    let (shape, tp1, _) = broadcast_shape(shape1, shape2, default_order)?;
165    let (storage, layout) = tensor.into_raw_parts();
166    let layout = update_layout_by_shape(&layout, &shape, &tp1, default_order)?;
167    unsafe { Ok(TensorBase::new_unchecked(storage, layout)) }
168}
169
170/// Broadcasts an array to a specified shape.
171///
172/// <div class="warning">
173///
174/// **Row/Column Major Notice**
175///
176/// This function behaves differently on default orders ([`RowMajor`] and [`ColMajor`]) of device.
177///
178/// </div>
179///
180/// # Parameters
181///
182/// - `tensor`: [`&TensorAny<R, T, B, D>`](TensorAny)
183///
184///   - The input tensor to be broadcasted.
185///
186/// - `shape`: impl [`DimAPI`]
187///
188///   - The shape of the desired output tensor after broadcasting.
///   - Please note [`IxD`] (`Vec<usize>`) and [`Ix<N>`] (`[usize; N]`) behave differently here.
190///     [`IxD`] will give dynamic shape tensor, while [`Ix<N>`] will give static shape tensor.
191///
192/// # Returns
193///
194/// - [`TensorView<'_, T, B, D2>`]
195///
196///   - A readonly view on the original tensor with the given shape. It is typically not contiguous
197///     (perform [`to_contig`] afterwards if you require contiguous owned tensors).
198///   - Furthermore, more than one element of a broadcasted tensor may refer to a single memory
199///     location (zero strides at the broadcasted axes). Writing values to broadcasted tensors is
200///     dangerous, but RSTSR will generally not panic on this behavior.
201///   - If you want to convert the tensor itself (taking the ownership instead of returning view),
202///     use [`into_broadcast`] instead.
203///
204/// # Examples
205///
206/// The following example demonstrates how to use `to_broadcast` to broadcast a 1-D tensor
207/// (3-element vector) to a 2-D tensor (2x3 matrix) by repeating the original data along a new axis.
208///
209/// ```rust
210/// use rstsr::prelude::*;
211/// let mut device = DeviceCpu::default();
212/// device.set_default_order(RowMajor);
213///
214/// let a = rt::tensor_from_nested!([1, 2, 3], &device);
215///
216/// // broadcast (3, ) -> (2, 3) in row-major:
217/// let result = a.to_broadcast(vec![2, 3]);
218/// let expected = rt::tensor_from_nested!(
219///     [[1, 2, 3],
220///      [1, 2, 3]],
221///     &device);
222/// assert!(rt::allclose!(&result, &expected));
223/// ```
224///
/// Please note that the above example only works in RowMajor order. In ColMajor order, the
226/// broadcasting will be done along the other axis:
227///
228/// ```rust
229/// use rstsr::prelude::*;
230/// let mut device = DeviceCpu::default();
231/// device.set_default_order(ColMajor);
232///
233/// let a = rt::tensor_from_nested!([1, 2, 3], &device);
234/// // in col-major, broadcast (3, ) -> (2, 3) will fail:
235/// let result = a.to_broadcast_f(vec![2, 3]);
236/// assert!(result.is_err());
237///
238/// // broadcast (3, ) -> (3, 2) in col-major:
239/// let result = a.to_broadcast(vec![3, 2]);
240/// let expected = rt::tensor_from_nested!(
241///     [[1, 1],
242///      [2, 2],
243///      [3, 3]],
244///     &device);
245/// assert!(rt::allclose!(&result, &expected));
246/// ```
247///
248/// # Panics
249///
250/// - Incompatible shapes to be broadcasted.
251///
252/// # Elaborated examples
253///
254/// ## Broadcasting behavior (in row-major)
255///
256/// This example does not directly call this function `to_broadcast`, but demonstrates the
257/// broadcasting behavior.
258///
259/// ```rust
260/// use rstsr::prelude::*;
261/// let mut device = DeviceCpu::default();
262/// device.set_default_order(RowMajor);
263///
264/// // A      (4d tensor):  8 x 1 x 6 x 1
265/// // B      (3d tensor):      7 x 1 x 5
266/// // ----------------------------------
267/// // Result (4d tensor):  8 x 7 x 6 x 5
268/// let a = rt::arange((48, &device)).into_shape([8, 1, 6, 1]);
269/// let b = rt::arange((35, &device)).into_shape([7, 1, 5]);
270/// let result = &a + &b;
271/// assert_eq!(result.shape(), &[8, 7, 6, 5]);
272///
273/// // A      (2d tensor):  5 x 4
274/// // B      (1d tensor):      1
275/// // --------------------------
276/// // Result (2d tensor):  5 x 4
277/// let a = rt::arange((20, &device)).into_shape([5, 4]);
278/// let b = rt::arange((1, &device)).into_shape([1]);
279/// let result = &a + &b;
280/// assert_eq!(result.shape(), &[5, 4]);
281///
282/// // A      (2d tensor):  5 x 4
283/// // B      (1d tensor):      4
284/// // --------------------------
285/// // Result (2d tensor):  5 x 4
286/// let a = rt::arange((20, &device)).into_shape([5, 4]);
287/// let b = rt::arange((4, &device)).into_shape([4]);
288/// let result = &a + &b;
289/// assert_eq!(result.shape(), &[5, 4]);
290///
291/// // A      (3d tensor):  15 x 3 x 5
292/// // B      (3d tensor):  15 x 1 x 5
293/// // -------------------------------
294/// // Result (3d tensor):  15 x 3 x 5
295/// let a = rt::arange((225, &device)).into_shape([15, 3, 5]);
296/// let b = rt::arange((75, &device)).into_shape([15, 1, 5]);
297/// let result = &a + &b;
298/// assert_eq!(result.shape(), &[15, 3, 5]);
299///
300/// // A      (3d tensor):  15 x 3 x 5
301/// // B      (2d tensor):       3 x 5
302/// // -------------------------------
303/// // Result (3d tensor):  15 x 3 x 5
304/// let a = rt::arange((225, &device)).into_shape([15, 3, 5]);
305/// let b = rt::arange((15, &device)).into_shape([3, 5]);
306/// let result = &a + &b;
307/// assert_eq!(result.shape(), &[15, 3, 5]);
308///
309/// // A      (3d tensor):  15 x 3 x 5
310/// // B      (2d tensor):       3 x 1
311/// // -------------------------------
312/// // Result (3d tensor):  15 x 3 x 5
313/// let a = rt::arange((225, &device)).into_shape([15, 3, 5]);
314/// let b = rt::arange((3, &device)).into_shape([3, 1]);
315/// let result = &a + &b;
316/// assert_eq!(result.shape(), &[15, 3, 5]);
317/// ```
318///
319/// ## Broadcasting behavior (in col-major)
320///
321/// This example does not directly call this function `to_broadcast`, but demonstrates the
322/// broadcasting behavior.
323///
324/// ```rust
325/// use rstsr::prelude::*;
326/// let mut device = DeviceCpu::default();
327/// device.set_default_order(ColMajor);
328///
329/// // A      (4d tensor):  1 x 6 x 1 x 8
330/// // B      (3d tensor):  5 x 1 x 7
331/// // ----------------------------------
332/// // Result (4d tensor):  5 x 6 x 7 x 8
333/// let a = rt::arange((48, &device)).into_shape([1, 6, 1, 8]);
334/// let b = rt::arange((35, &device)).into_shape([5, 1, 7]);
335/// let result = &a + &b;
336/// assert_eq!(result.shape(), &[5, 6, 7, 8]);
337///
338/// // A      (2d tensor):  4 x 5
339/// // B      (1d tensor):  1
340/// // --------------------------
341/// // Result (2d tensor):  4 x 5
342/// let a = rt::arange((20, &device)).into_shape([4, 5]);
343/// let b = rt::arange((1, &device)).into_shape([1]);
344/// let result = &a + &b;
345/// assert_eq!(result.shape(), &[4, 5]);
346///
347/// // A      (2d tensor):  4 x 5
348/// // B      (1d tensor):  4
349/// // --------------------------
350/// // Result (2d tensor):  4 x 5
351/// let a = rt::arange((20, &device)).into_shape([4, 5]);
352/// let b = rt::arange((4, &device)).into_shape([4]);
353/// let result = &a + &b;
354/// assert_eq!(result.shape(), &[4, 5]);
355///
356/// // A      (3d tensor):  5 x 3 x 15
357/// // B      (3d tensor):  5 x 1 x 15
358/// // -------------------------------
359/// // Result (3d tensor):  5 x 3 x 15
360/// let a = rt::arange((225, &device)).into_shape([5, 3, 15]);
361/// let b = rt::arange((75, &device)).into_shape([5, 1, 15]);
362/// let result = &a + &b;
363/// assert_eq!(result.shape(), &[5, 3, 15]);
364///
365/// // A      (3d tensor):  5 x 3 x 15
366/// // B      (2d tensor):  5 x 3
367/// // -------------------------------
368/// // Result (3d tensor):  5 x 3 x 15
369/// let a = rt::arange((225, &device)).into_shape([5, 3, 15]);
370/// let b = rt::arange((15, &device)).into_shape([5, 3]);
371/// let result = &a + &b;
372/// assert_eq!(result.shape(), &[5, 3, 15]);
373///
374/// // A      (3d tensor):  5 x 3 x 15
375/// // B      (2d tensor):  1 x 3
376/// // -------------------------------
377/// // Result (3d tensor):  5 x 3 x 15
378/// let a = rt::arange((225, &device)).into_shape([5, 3, 15]);
379/// let b = rt::arange((3, &device)).into_shape([1, 3]);
380/// let result = &a + &b;
381/// assert_eq!(result.shape(), &[5, 3, 15]);
382/// ```
383///
384/// # See also
385///
386/// ## Detailed broadcasting rules
387///
388/// - Python Array API standard: [Broadcasting rules](https://data-apis.org/array-api/latest/API_specification/broadcasting.html)
389/// - NumPy: [Broadcasting](https://numpy.org/doc/stable/user/basics.broadcasting.html)
390///
391/// ## Similar function from other crates/libraries
392///
393/// - Python Array API standard: [`broadcast_to`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.broadcast_to.html)
394/// - NumPy: [`numpy.broadcast_to`](https://numpy.org/doc/stable/reference/generated/numpy.broadcast_to.html)
395/// - ndarray: [`ndarray::broadcast`](https://docs.rs/ndarray/latest/ndarray/struct.ArrayBase.html#method.broadcast)
396///
397/// ## Related functions in RSTSR
398///
399/// - [`broadcast_arrays`]: Broadcasts any number of arrays against each other.
400///
401/// ## Variants of this function
402///
403/// - [`to_broadcast`]: Standard version.
404/// - [`to_broadcast_f`]: Fallible version.
405/// - [`into_broadcast`]: Consuming version that takes ownership of the input tensor.
406/// - [`into_broadcast_f`]: Consuming and fallible version, actual implementation.
407/// - [`broadcast_to`]: Alias for `to_broadcast` (name of Python Array API standard).
408/// - Associated methods on [`TensorAny`]:
409///
410///   - [`TensorAny::to_broadcast`]
411///   - [`TensorAny::to_broadcast_f`]
412///   - [`TensorAny::into_broadcast`]
413///   - [`TensorAny::into_broadcast_f`]
414///   - [`TensorAny::broadcast_to`]
415pub fn to_broadcast<R, T, B, D, D2>(tensor: &TensorAny<R, T, B, D>, shape: D2) -> TensorView<'_, T, B, D2>
416where
417    D: DimAPI + DimMaxAPI<D2, Max = D2>,
418    D2: DimAPI,
419    R: DataAPI<Data = B::Raw>,
420    B: DeviceAPI<T>,
421{
422    into_broadcast_f(tensor.view(), shape).rstsr_unwrap()
423}
424
425/// Broadcasts an array to a specified shape.
426///
427/// # See also
428///
429/// Refer to [`to_broadcast`] or [`into_broadcast`] for detailed documentation.
430pub fn broadcast_to<R, T, B, D, D2>(tensor: &TensorAny<R, T, B, D>, shape: D2) -> TensorView<'_, T, B, D2>
431where
432    D: DimAPI + DimMaxAPI<D2, Max = D2>,
433    D2: DimAPI,
434    R: DataAPI<Data = B::Raw>,
435    B: DeviceAPI<T>,
436{
437    into_broadcast_f(tensor.view(), shape).rstsr_unwrap()
438}
439
440/// Broadcasts an array to a specified shape.
441///
442/// # See also
443///
444/// Refer to [`to_broadcast`] or [`into_broadcast`] for detailed documentation.
445pub fn to_broadcast_f<R, T, B, D, D2>(tensor: &TensorAny<R, T, B, D>, shape: D2) -> Result<TensorView<'_, T, B, D2>>
446where
447    D: DimAPI + DimMaxAPI<D2, Max = D2>,
448    D2: DimAPI,
449    R: DataAPI<Data = B::Raw>,
450    B: DeviceAPI<T>,
451{
452    into_broadcast_f(tensor.view(), shape)
453}
454
455/// Broadcasts an array to a specified shape.
456///
457/// <div class="warning">
458///
459/// **Row/Column Major Notice**
460///
461/// This function behaves differently on default orders ([`RowMajor`] and [`ColMajor`]) of device.
462///
463/// </div>
464///
465/// # Parameters
466///
467/// - `tensor`: [`TensorAny<R, T, B, D>`]
468///
469///   - The input tensor to be broadcasted.
470///   - Ownership of input tensor is taken.
471///
472/// - `shape`: impl [`DimAPI`]
473///
474///   - The shape of the desired output tensor after broadcasting.
475///   - Please note [`IxD`] (`Vec<usize>`) and [`Ix<N>`] (`[usize; N]`) behaves differently here.
476///     [`IxD`] will give dynamic shape tensor, while [`Ix<N>`] will give static shape tensor.
477///
478///
479/// # Returns
480///
481/// - [`TensorAny<R, T, B, D2>`]
482///
483///   - The tensor with the given shape. It is typically not contiguous (perform [`to_contig`]
484///     afterwards if requires a contiguous owned tensor).
485///   - Furthermore, more than one element of a broadcasted tensor may refer to a single memory
486///     location (zero strides at the broadcasted axes).
487///   - Ownership of the returned tensor is transferred from the input tensor. Only the layout is
488///     modified; the underlying data remains unchanged.
489///
490/// # See also
491///
492/// Refer to [`to_broadcast`] for more detailed documentation.
493pub fn into_broadcast<R, T, B, D, D2>(tensor: TensorAny<R, T, B, D>, shape: D2) -> TensorAny<R, T, B, D2>
494where
495    R: DataAPI<Data = B::Raw>,
496    B: DeviceAPI<T>,
497    D: DimAPI + DimMaxAPI<D2, Max = D2>,
498    D2: DimAPI,
499{
500    into_broadcast_f(tensor, shape).rstsr_unwrap()
501}
502
503impl<R, T, B, D> TensorAny<R, T, B, D>
504where
505    R: DataAPI<Data = B::Raw>,
506    B: DeviceAPI<T>,
507    D: DimAPI,
508{
509    /// Broadcasts an array to a specified shape.
510    ///
511    /// # See also
512    ///
513    /// Refer to [`to_broadcast`] or [`into_broadcast`] for detailed documentation.
514    pub fn to_broadcast<D2>(&self, shape: D2) -> TensorView<'_, T, B, D2>
515    where
516        D2: DimAPI,
517        D: DimMaxAPI<D2, Max = D2>,
518    {
519        to_broadcast(self, shape)
520    }
521
522    /// Broadcasts an array to a specified shape.
523    ///
524    /// # See also
525    ///
526    /// Refer to [`to_broadcast`] or [`into_broadcast`] for detailed documentation.
527    pub fn broadcast_to<D2>(&self, shape: D2) -> TensorView<'_, T, B, D2>
528    where
529        D2: DimAPI,
530        D: DimMaxAPI<D2, Max = D2>,
531    {
532        broadcast_to(self, shape)
533    }
534
535    /// Broadcasts an array to a specified shape.
536    ///
537    /// # See also
538    ///
539    /// Refer to [`to_broadcast`] or [`into_broadcast`] for detailed documentation.
540    pub fn to_broadcast_f<D2>(&self, shape: D2) -> Result<TensorView<'_, T, B, D2>>
541    where
542        D2: DimAPI,
543        D: DimMaxAPI<D2, Max = D2>,
544    {
545        to_broadcast_f(self, shape)
546    }
547
548    /// Broadcasts an array to a specified shape.
549    ///
550    /// # See also
551    ///
552    /// Refer to [`to_broadcast`] or [`into_broadcast`] for detailed documentation.
553    pub fn into_broadcast<D2>(self, shape: D2) -> TensorAny<R, T, B, D2>
554    where
555        D2: DimAPI,
556        D: DimMaxAPI<D2, Max = D2>,
557    {
558        into_broadcast(self, shape)
559    }
560
561    /// Broadcasts an array to a specified shape.
562    ///
563    /// # See also
564    ///
565    /// Refer to [`to_broadcast`] or [`into_broadcast`] for detailed documentation.
566    pub fn into_broadcast_f<D2>(self, shape: D2) -> Result<TensorAny<R, T, B, D2>>
567    where
568        D2: DimAPI,
569        D: DimMaxAPI<D2, Max = D2>,
570    {
571        into_broadcast_f(self, shape)
572    }
573}
574
575/* #endregion */
576
577#[cfg(test)]
578mod tests {
579
    // Regression test: keeps the row-major `to_broadcast` doc example in sync
    // with actual behavior (formatting intentionally mirrors the doctest).
    #[test]
    #[rustfmt::skip]
    fn doc_broadcast_to() {
        use rstsr::prelude::*;
        let mut device = DeviceCpu::default();
        device.set_default_order(RowMajor);

        let a = rt::tensor_from_nested!([1, 2, 3], &device);

        // broadcast (3, ) -> (2, 3) in row-major:
        let result = a.to_broadcast(vec![2, 3]);
        let expected = rt::tensor_from_nested!(
            [[1, 2, 3],
             [1, 2, 3]],
            &device);
        assert!(rt::allclose!(&result, &expected));
    }
597
    // Regression test: keeps the col-major `to_broadcast` doc example in sync.
    // In col-major, (3,) -> (2, 3) must fail while (3,) -> (3, 2) succeeds.
    #[test]
    #[rustfmt::skip]
    fn doc_broadcast_to_col_major() {
        use rstsr::prelude::*;
        let mut device = DeviceCpu::default();
        device.set_default_order(ColMajor);

        let a = rt::tensor_from_nested!([1, 2, 3], &device);
        // in col-major, broadcast (3, ) -> (2, 3) will fail:
        let result = a.to_broadcast_f(vec![2, 3]);
        assert!(result.is_err());

        // broadcast (3, ) -> (3, 2) in col-major:
        let result = a.to_broadcast(vec![3, 2]);
        let expected = rt::tensor_from_nested!(
            [[1, 1],
             [2, 2],
             [3, 3]],
            &device);
        assert!(rt::allclose!(&result, &expected));
    }
619
    // Regression test: mirrors the "Broadcasting behavior (in row-major)"
    // elaborated doc example; checks only the resulting broadcast shapes.
    #[test]
    fn doc_broadcast_to_elaborated_row_major() {
        use rstsr::prelude::*;
        let mut device = DeviceCpu::default();
        device.set_default_order(RowMajor);

        // A      (4d tensor):  8 x 1 x 6 x 1
        // B      (3d tensor):      7 x 1 x 5
        // ----------------------------------
        // Result (4d tensor):  8 x 7 x 6 x 5
        let a = rt::arange((48, &device)).into_shape([8, 1, 6, 1]);
        let b = rt::arange((35, &device)).into_shape([7, 1, 5]);
        let result = &a + &b;
        assert_eq!(result.shape(), &[8, 7, 6, 5]);

        // A      (2d tensor):  5 x 4
        // B      (1d tensor):      1
        // --------------------------
        // Result (2d tensor):  5 x 4
        let a = rt::arange((20, &device)).into_shape([5, 4]);
        let b = rt::arange((1, &device)).into_shape([1]);
        let result = &a + &b;
        assert_eq!(result.shape(), &[5, 4]);

        // A      (2d tensor):  5 x 4
        // B      (1d tensor):      4
        // --------------------------
        // Result (2d tensor):  5 x 4
        let a = rt::arange((20, &device)).into_shape([5, 4]);
        let b = rt::arange((4, &device)).into_shape([4]);
        let result = &a + &b;
        assert_eq!(result.shape(), &[5, 4]);

        // A      (3d tensor):  15 x 3 x 5
        // B      (3d tensor):  15 x 1 x 5
        // -------------------------------
        // Result (3d tensor):  15 x 3 x 5
        let a = rt::arange((225, &device)).into_shape([15, 3, 5]);
        let b = rt::arange((75, &device)).into_shape([15, 1, 5]);
        let result = &a + &b;
        assert_eq!(result.shape(), &[15, 3, 5]);

        // A      (3d tensor):  15 x 3 x 5
        // B      (2d tensor):       3 x 5
        // -------------------------------
        // Result (3d tensor):  15 x 3 x 5
        let a = rt::arange((225, &device)).into_shape([15, 3, 5]);
        let b = rt::arange((15, &device)).into_shape([3, 5]);
        let result = &a + &b;
        assert_eq!(result.shape(), &[15, 3, 5]);

        // A      (3d tensor):  15 x 3 x 5
        // B      (2d tensor):       3 x 1
        // -------------------------------
        // Result (3d tensor):  15 x 3 x 5
        let a = rt::arange((225, &device)).into_shape([15, 3, 5]);
        let b = rt::arange((3, &device)).into_shape([3, 1]);
        let result = &a + &b;
        assert_eq!(result.shape(), &[15, 3, 5]);
    }
680
    // Regression test: mirrors the "Broadcasting behavior (in col-major)"
    // elaborated doc example; shapes are aligned from the leading axis.
    #[test]
    fn doc_broadcast_to_elaborated_col_major() {
        use rstsr::prelude::*;
        let mut device = DeviceCpu::default();
        device.set_default_order(ColMajor);

        // A      (4d tensor):  1 x 6 x 1 x 8
        // B      (3d tensor):  5 x 1 x 7
        // ----------------------------------
        // Result (4d tensor):  5 x 6 x 7 x 8
        let a = rt::arange((48, &device)).into_shape([1, 6, 1, 8]);
        let b = rt::arange((35, &device)).into_shape([5, 1, 7]);
        let result = &a + &b;
        assert_eq!(result.shape(), &[5, 6, 7, 8]);

        // A      (2d tensor):  4 x 5
        // B      (1d tensor):  1
        // --------------------------
        // Result (2d tensor):  4 x 5
        let a = rt::arange((20, &device)).into_shape([4, 5]);
        let b = rt::arange((1, &device)).into_shape([1]);
        let result = &a + &b;
        assert_eq!(result.shape(), &[4, 5]);

        // A      (2d tensor):  4 x 5
        // B      (1d tensor):  4
        // --------------------------
        // Result (2d tensor):  4 x 5
        let a = rt::arange((20, &device)).into_shape([4, 5]);
        let b = rt::arange((4, &device)).into_shape([4]);
        let result = &a + &b;
        assert_eq!(result.shape(), &[4, 5]);

        // A      (3d tensor):  5 x 3 x 15
        // B      (3d tensor):  5 x 1 x 15
        // -------------------------------
        // Result (3d tensor):  5 x 3 x 15
        let a = rt::arange((225, &device)).into_shape([5, 3, 15]);
        let b = rt::arange((75, &device)).into_shape([5, 1, 15]);
        let result = &a + &b;
        assert_eq!(result.shape(), &[5, 3, 15]);

        // A      (3d tensor):  5 x 3 x 15
        // B      (2d tensor):  5 x 3
        // -------------------------------
        // Result (3d tensor):  5 x 3 x 15
        let a = rt::arange((225, &device)).into_shape([5, 3, 15]);
        let b = rt::arange((15, &device)).into_shape([5, 3]);
        let result = &a + &b;
        assert_eq!(result.shape(), &[5, 3, 15]);

        // A      (3d tensor):  5 x 3 x 15
        // B      (2d tensor):  1 x 3
        // -------------------------------
        // Result (3d tensor):  5 x 3 x 15
        let a = rt::arange((225, &device)).into_shape([5, 3, 15]);
        let b = rt::arange((3, &device)).into_shape([1, 3]);
        let result = &a + &b;
        assert_eq!(result.shape(), &[5, 3, 15]);
    }
741
    // Regression test: keeps the row-major `broadcast_arrays` doc example in
    // sync with actual behavior.
    #[test]
    #[rustfmt::skip]
    fn doc_broadcast_arrays_row_major() {
        use rstsr::prelude::*;
        let mut device = DeviceCpu::default();
        device.set_default_order(RowMajor);

        let a = rt::asarray((vec![1, 2, 3], &device)).into_shape([3]);
        let b = rt::asarray((vec![4, 5], &device)).into_shape([2, 1]);

        let result = rt::broadcast_arrays(vec![a, b]);
        let expected_a = rt::tensor_from_nested!(
            [[1, 2, 3],
             [1, 2, 3]],
            &device);
        let expected_b = rt::tensor_from_nested!(
            [[4, 4, 4],
             [5, 5, 5]],
            &device);
        assert!(rt::allclose!(&result[0], &expected_a));
        assert!(rt::allclose!(&result[1], &expected_b));
    }
764
765    #[test]
766    #[rustfmt::skip]
767    fn doc_broadcast_arrays_col_major() {
768        use rstsr::prelude::*;
769        let mut device = DeviceCpu::default();
770        device.set_default_order(ColMajor);
771
772        let a = rt::asarray((vec![1, 2, 3], &device)).into_shape([1, 3]);
773        let b = rt::asarray((vec![4, 5], &device)).into_shape([2, 1]);
774
775        let result = rt::broadcast_arrays(vec![a, b]);
776        let expected_a = rt::tensor_from_nested!(
777            [[1, 2, 3],
778             [1, 2, 3]],
779            &device);
780        let expected_b = rt::tensor_from_nested!(
781            [[4, 4, 4],
782             [5, 5, 5]],
783            &device);
784        assert!(rt::allclose!(&result[0], &expected_a));
785        assert!(rt::allclose!(&result[1], &expected_b));
786    }
787}