rstsr_core/tensor/manuplication/broadcast.rs

use crate::prelude_dev::*;

/* #region broadcast_shapes */

/// Broadcasts shapes against each other and returns the resulting shape.
///
/// This function computes the shape that results from broadcasting the given
/// shapes against one another according to the standard broadcasting rules.
/// It does not require actual tensor data, only the shapes.
///
/// <div class="warning">
///
/// **Row/Column Major Notice**
///
/// This function behaves differently depending on the device's default order ([`RowMajor`] or
/// [`ColMajor`]).
///
/// </div>
///
/// # Parameters
///
/// - `shapes`: A slice of shapes, where each shape is [`IxD`] (`Vec<usize>`).
/// - `order`: The memory order ([`RowMajor`] or [`ColMajor`]) that determines how broadcasting
///   rules are applied.
///
/// # Returns
///
/// - `IxD`: The resulting shape after broadcasting all input shapes together.
///
/// # Examples
///
/// Broadcast two shapes in row-major order:
///
/// ```rust
/// # use rstsr::prelude::*;
/// // A      (4d array):  8 x 1 x 6 x 1
/// // B      (3d array):      7 x 1 x 5
/// // ---------------------------------
/// // Result (4d array):  8 x 7 x 6 x 5
/// let shape1 = vec![8, 1, 6, 1];
/// let shape2 = vec![7, 1, 5];
/// let result = rt::broadcast_shapes(&[shape1, shape2], RowMajor);
/// println!("{:?}", result);
/// // [8, 7, 6, 5]
/// assert_eq!(result, vec![8, 7, 6, 5]);
/// ```
///
/// In col-major order, broadcasting aligns shapes from the left instead of from the right:
///
/// ```rust
/// # use rstsr::prelude::*;
/// // A      (4d array):  1 x 6 x 1 x 8
/// // B      (3d array):  5 x 1 x 7
/// // ---------------------------------
/// // Result (4d array):  5 x 6 x 7 x 8
/// let shape1 = vec![1, 6, 1, 8];
/// let shape2 = vec![5, 1, 7];
/// let result = rt::broadcast_shapes(&[shape1, shape2], ColMajor);
/// println!("{:?}", result);
/// // [5, 6, 7, 8]
/// assert_eq!(result, vec![5, 6, 7, 8]);
/// ```
///
/// Broadcast multiple shapes:
///
/// ```rust
/// # use rstsr::prelude::*;
/// // Three shapes: (1,), (3, 1), (3, 4) -> (3, 4)
/// let shapes = vec![vec![1], vec![3, 1], vec![3, 4]];
/// let result = rt::broadcast_shapes(&shapes, RowMajor);
/// println!("{:?}", result);
/// // [3, 4]
/// assert_eq!(result, vec![3, 4]);
/// ```
///
/// # Notes of API accordance
///
/// - Array-API: `broadcast_shapes(*shapes)` ([`broadcast_shapes`](https://data-apis.org/array-api/2024.12/API_specification/generated/array_api.broadcast_shapes.html))
/// - NumPy: `numpy.broadcast_shapes(*shapes)` ([`numpy.broadcast_shapes`](https://numpy.org/doc/stable/reference/generated/numpy.broadcast_shapes.html))
/// - RSTSR: `rt::broadcast_shapes(shapes, order)`
///
/// # Panics
///
/// Panics if the shapes cannot be broadcast together.
///
/// For a fallible version, use [`broadcast_shapes_f`].
///
/// # See also
///
/// ## Variants of this function
///
/// - [`broadcast_shapes_f`]: Fallible version, actual implementation.
///
/// ## Related functions in RSTSR
///
/// - [`broadcast_arrays`]: Broadcasts any number of arrays against each other.
/// - [`broadcast_to`]: Broadcasts an array to a specified shape.
pub fn broadcast_shapes(shapes: &[IxD], order: FlagOrder) -> IxD {
    broadcast_shapes_f(shapes, order).rstsr_unwrap()
}

/// Broadcasts shapes against each other and returns the resulting shape.
///
/// See also [`broadcast_shapes`].
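///
/// # Examples
///
/// A minimal sketch of handling the returned `Result` (not from the original docs, and assuming
/// this fallible variant is re-exported under `rt::` like [`broadcast_shapes`]); the shapes are
/// chosen so that broadcasting fails in row-major order:
///
/// ```rust
/// # use rstsr::prelude::*;
/// // (2, 3) and (4,) are incompatible in row-major: trailing dims 3 vs 4
/// let result = rt::broadcast_shapes_f(&[vec![2, 3], vec![4]], RowMajor);
/// assert!(result.is_err());
/// ```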
pub fn broadcast_shapes_f(shapes: &[IxD], order: FlagOrder) -> Result<IxD> {
    // fast return if empty or single shape
    if shapes.is_empty() {
        return Ok(vec![]);
    }
    if shapes.len() == 1 {
        return Ok(shapes[0].clone());
    }

    // iteratively broadcast all shapes
    let mut result = shapes[0].clone();
    for shape in shapes.iter().skip(1) {
        let (new_result, _, _) = broadcast_shape(shape, &result, order)?;
        result = new_result;
    }
    Ok(result)
}

/* #endregion */

/* #region broadcast_arrays */

/// Broadcasts any number of arrays against each other.
///
/// <div class="warning">
///
/// **Row/Column Major Notice**
///
/// This function behaves differently depending on the device's default order ([`RowMajor`] or
/// [`ColMajor`]).
///
/// </div>
///
/// # Parameters
///
/// - `tensors`: [`Vec<TensorAny<R, T, B, IxD>>`](TensorAny)
///
///   - The tensors to be broadcasted.
///   - All tensors must be on the same device and share the same ownership type.
///   - This function takes ownership of the input tensors. If you want to obtain broadcasted views,
///     you need to create a new vector of views first.
///   - This function only accepts dynamic shape tensors ([`IxD`]).
///
/// # Returns
///
/// - [`Vec<TensorAny<R, T, B, IxD>>`](TensorAny)
///
///   - A vector of broadcasted tensors. Each tensor has the same shape after broadcasting.
///   - The ownership of the underlying data is moved from the input tensors to the output tensors.
///   - The tensors are typically not contiguous (with zero strides at the broadcasted axes).
///     Writing values to broadcasted tensors is dangerous, but RSTSR will generally not panic on
///     this behavior. Perform [`to_contig`] afterwards if you require owned contiguous tensors.
///
/// # Examples
///
/// The following example demonstrates how to use `broadcast_arrays` to broadcast two tensors:
///
/// ```rust
/// # use rstsr::prelude::*;
/// # let mut device = DeviceCpu::default();
/// # device.set_default_order(RowMajor);
/// let a = rt::asarray((vec![1, 2, 3], &device)).into_shape([3]);
/// println!("{a}");
/// // [ 1 2 3]
/// let b = rt::asarray((vec![4, 5], &device)).into_shape([2, 1]);
/// println!("{b}");
/// // [[ 4]
/// //  [ 5]]
///
/// let result = rt::broadcast_arrays(vec![a, b]);
/// println!("broadcasted a:\n{:}", result[0]);
/// // [[ 1 2 3]
/// //  [ 1 2 3]]
/// println!("broadcasted b:\n{:}", result[1]);
/// // [[ 4 4 4]
/// //  [ 5 5 5]]
/// ```
///
/// Please note that the above code only works in [RowMajor] order.
///
/// In [ColMajor] order, the broadcasting fails because the rules are applied differently: shape
/// comparison starts from the left instead of from the right as in row-major, so these shapes
/// are incompatible:
///
/// ```rust,should_panic
/// # use rstsr::prelude::*;
/// # let mut device = DeviceCpu::default();
/// device.set_default_order(ColMajor);
/// let a = rt::asarray((vec![1, 2, 3], &device)).into_shape([3]);
/// let b = rt::asarray((vec![4, 5], &device)).into_shape([2, 1]);
/// let result = rt::broadcast_arrays(vec![a, b]);
/// ```
///
/// You need to make the following change (give `a` the shape `[1, 3]`) for the [ColMajor] case
/// to work:
///
/// ```rust
/// # use rstsr::prelude::*;
/// # let mut device = DeviceCpu::default();
/// device.set_default_order(ColMajor);
/// let a = rt::asarray((vec![1, 2, 3], &device)).into_shape([1, 3]);
/// let b = rt::asarray((vec![4, 5], &device)).into_shape([2, 1]);
/// let result = rt::broadcast_arrays(vec![a, b]);
/// ```
///
/// # Notes of API accordance
///
/// - Array-API: `broadcast_arrays(*arrays)` ([`broadcast_arrays`](https://data-apis.org/array-api/2024.12/API_specification/generated/array_api.broadcast_arrays.html))
/// - NumPy: `numpy.broadcast_arrays(*args, subok=False)` ([`numpy.broadcast_arrays`](https://numpy.org/doc/stable/reference/generated/numpy.broadcast_arrays.html))
/// - RSTSR: `rt::broadcast_arrays(tensors)`
///
/// # Panics
///
/// Panics if the tensors cannot be broadcast together, or if tensors are on different devices.
///
/// For a fallible version, use [`broadcast_arrays_f`].
///
/// # See also
///
/// ## Related functions in RSTSR
///
/// - [`to_broadcast`]: Broadcasts a single array to a specified shape.
///
/// ## Variants of this function
///
/// - [`broadcast_arrays_f`]: Fallible version, actual implementation.
pub fn broadcast_arrays<R, T, B>(tensors: Vec<TensorAny<R, T, B, IxD>>) -> Vec<TensorAny<R, T, B, IxD>>
where
    R: DataAPI<Data = B::Raw>,
    B: DeviceAPI<T>,
{
    broadcast_arrays_f(tensors).rstsr_unwrap()
}

/// Broadcasts any number of arrays against each other.
///
/// See also [`broadcast_arrays`].
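///
/// # Examples
///
/// A minimal sketch of handling the returned `Result` (not from the original docs); tensor
/// construction follows the [`broadcast_arrays`] examples above, and the shapes `(3,)` and
/// `(2,)` are chosen so that broadcasting fails:
///
/// ```rust
/// # use rstsr::prelude::*;
/// # let mut device = DeviceCpu::default();
/// # device.set_default_order(RowMajor);
/// let a = rt::asarray((vec![1, 2, 3], &device)).into_shape([3]);
/// let b = rt::asarray((vec![4, 5], &device)).into_shape([2]);
/// // (3,) and (2,) cannot broadcast together, so the fallible variant returns `Err`
/// assert!(rt::broadcast_arrays_f(vec![a, b]).is_err());
/// ```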
pub fn broadcast_arrays_f<R, T, B>(tensors: Vec<TensorAny<R, T, B, IxD>>) -> Result<Vec<TensorAny<R, T, B, IxD>>>
where
    R: DataAPI<Data = B::Raw>,
    B: DeviceAPI<T>,
{
    // fast return if there are fewer than two tensors
    if tensors.len() <= 1 {
        return Ok(tensors);
    }
    let device_b = tensors[0].device().clone();
    let default_order = device_b.default_order();
    let mut shape_b = tensors[0].shape().clone();
    for tensor in tensors.iter().skip(1) {
        rstsr_assert!(device_b.same_device(tensor.device()), DeviceMismatch)?;
        let shape = tensor.shape();
        let (shape, _, _) = broadcast_shape(shape, &shape_b, default_order)?;
        shape_b = shape;
    }
    let mut tensors_new = Vec::with_capacity(tensors.len());
    for tensor in tensors {
        let tensor = into_broadcast_f(tensor, shape_b.clone())?;
        tensors_new.push(tensor);
    }
    Ok(tensors_new)
}

/* #endregion */

/* #region broadcast_to */

/// Broadcasts an array to a specified shape.
///
/// See also [`to_broadcast`].
pub fn into_broadcast_f<R, T, B, D, D2>(tensor: TensorAny<R, T, B, D>, shape: D2) -> Result<TensorAny<R, T, B, D2>>
where
    R: DataAPI<Data = B::Raw>,
    B: DeviceAPI<T>,
    D: DimAPI + DimMaxAPI<D2, Max = D2>,
    D2: DimAPI,
{
    let shape1 = tensor.shape();
    let shape2 = &shape;
    let default_order = tensor.device().default_order();
    let (shape, tp1, _) = broadcast_shape(shape1, shape2, default_order)?;
    // For broadcast_to, the result shape must match the target shape exactly.
    // This ensures we cannot broadcast from higher dimensions to lower dimensions.
    rstsr_assert_eq!(
        shape.as_ref(),
        shape2.as_ref(),
        InvalidLayout,
        "Cannot broadcast to a shape with fewer dimensions or incompatible size."
    )?;
    let (storage, layout) = tensor.into_raw_parts();
    let layout = update_layout_by_shape(&layout, &shape, &tp1, default_order)?;
    unsafe { Ok(TensorBase::new_unchecked(storage, layout)) }
}

/// Broadcasts an array to a specified shape.
///
/// <div class="warning">
///
/// **Row/Column Major Notice**
///
/// This function behaves differently depending on the device's default order ([`RowMajor`] or
/// [`ColMajor`]).
///
/// </div>
///
/// # Parameters
///
/// - `tensor`: [`&TensorAny<R, T, B, D>`](TensorAny)
///
///   - The input tensor to be broadcasted.
///   - Note on variant [`into_broadcast`]: it takes ownership of the input tensor
///     ([`TensorAny<R, T, B, D>`]) and changes only the layout; the underlying data is untouched.
///
/// - `shape`: impl [`DimAPI`]
///
///   - The shape of the desired output tensor after broadcasting.
///   - Please note that [`IxD`] (`Vec<usize>`) and [`Ix<N>`] (`[usize; N]`) behave differently
///     here: [`IxD`] gives a dynamic-shape tensor, while [`Ix<N>`] gives a static-shape tensor.
///     Unless you are handling static-shape tensors, always use [`IxD`] (`Vec<usize>`) to specify
///     the shape.
///
/// # Returns
///
/// - [`TensorView<'_, T, B, D2>`]
///
///   - A readonly view on the original tensor with the given shape. It is typically not contiguous
///     (perform [`to_contig`] afterwards if you require contiguous owned tensors).
///   - Furthermore, more than one element of a broadcasted tensor may refer to a single memory
///     location (zero strides at the broadcasted axes). Writing values to broadcasted tensors is
///     dangerous, but RSTSR will generally not panic on this behavior.
///   - If you want to convert the tensor itself (taking the ownership instead of returning view),
///     use [`into_broadcast`] instead.
///
/// # Examples
///
/// The following example demonstrates how to use `to_broadcast` to broadcast a 1-D tensor
/// (3-element vector) to a 2-D tensor (2x3 matrix) by repeating the original data along a new axis.
///
/// ```rust
/// # use rstsr::prelude::*;
/// # let mut device = DeviceCpu::default();
/// # device.set_default_order(RowMajor);
/// // broadcast (3, ) -> (2, 3) in row-major:
/// let a = rt::tensor_from_nested!([1, 2, 3], &device);
/// let result = a.to_broadcast(vec![2, 3]);
/// println!("{result}");
/// // [[ 1 2 3]
/// //  [ 1 2 3]]
/// ```
///
/// Please note that the above example only works in [RowMajor] order. In [ColMajor] order, the
/// broadcasting is done along the other axis:
///
/// ```rust,should_panic
/// # use rstsr::prelude::*;
/// # let mut device = DeviceCpu::default();
/// device.set_default_order(ColMajor);
/// let a = rt::tensor_from_nested!([1, 2, 3], &device);
/// // in col-major, broadcast (3, ) -> (2, 3) will panic:
/// let result = a.to_broadcast(vec![2, 3]);
/// ```
///
/// The correct shape for col-major is `(3, 2)` instead:
///
/// ```rust
/// # use rstsr::prelude::*;
/// # let mut device = DeviceCpu::default();
/// device.set_default_order(ColMajor);
/// let a = rt::tensor_from_nested!([1, 2, 3], &device);
/// // broadcast (3, ) -> (3, 2) in col-major:
/// let result = a.to_broadcast(vec![3, 2]);
/// println!("{result}");
/// // [[ 1 1]
/// //  [ 2 2]
/// //  [ 3 3]]
/// ```
///
/// # Tips on common compilation errors
///
/// You may encounter compilation errors related to shape types, like the following:
///
/// ```compile_fail
/// # use rstsr::prelude::*;
/// # let mut device = DeviceCpu::default();
/// # device.set_default_order(RowMajor);
/// let input_array: Tensor<i32, _> = rt::asarray((0, &device));
/// let result = rt::broadcast_to(&input_array, [3]);
/// ```
///
/// ```text
///  33 |         let result = rt::broadcast_to(&input_array, [3]);
///     |                      ---------------- ^^^^^^^^^^^^ expected `[usize; 1]`, found `Vec<usize>`
///     |                      |
///     |                      required by a bound introduced by this call
///     |
///     = note: expected array `[usize; 1]`
///               found struct `std::vec::Vec<usize>`
/// note: required by a bound in `rstsr_core::prelude_dev::to_broadcast`
///    --> rstsr-core/src/tensor/manuplication/broadcast.rs:427:31
///     |
/// 425 | pub fn to_broadcast<R, T, B, D, D2>(tensor: &TensorAny<R, T, B, D>, shape: D2) -> TensorView<'_, T, B, D2>
///     |        ------------ required by a bound in this function
/// 426 | where
/// 427 |     D: DimAPI + DimMaxAPI<D2, Max = D2>,
///     |                               ^^^^^^^^ required by this bound in `to_broadcast`
/// ```
///
/// To resolve this, we suggest specifying the shape with [`IxD`] (`Vec<usize>`) instead of
/// [`Ix<N>`] (`[usize; N]`), unless you are an advanced API user handling static-shape tensors.
///
/// ```rust
/// # use rstsr::prelude::*;
/// # let mut device = DeviceCpu::default();
/// # device.set_default_order(RowMajor);
/// let input_array: Tensor<i32, _> = rt::asarray((0, &device));
/// let result = rt::broadcast_to(&input_array, vec![3]);
/// ```
///
/// # Elaborated examples
///
/// ## Broadcasting behavior (in row-major)
///
/// This example does not directly call `to_broadcast`, but demonstrates the broadcasting
/// behavior.
///
/// ```rust
/// use rstsr::prelude::*;
/// let mut device = DeviceCpu::default();
/// device.set_default_order(RowMajor);
///
/// // A      (4d tensor):  8 x 1 x 6 x 1
/// // B      (3d tensor):      7 x 1 x 5
/// // ----------------------------------
/// // Result (4d tensor):  8 x 7 x 6 x 5
/// let a = rt::arange((48, &device)).into_shape([8, 1, 6, 1]);
/// let b = rt::arange((35, &device)).into_shape([7, 1, 5]);
/// let result = &a + &b;
/// assert_eq!(result.shape(), &[8, 7, 6, 5]);
///
/// // A      (2d tensor):  5 x 4
/// // B      (1d tensor):      1
/// // --------------------------
/// // Result (2d tensor):  5 x 4
/// let a = rt::arange((20, &device)).into_shape([5, 4]);
/// let b = rt::arange((1, &device)).into_shape([1]);
/// let result = &a + &b;
/// assert_eq!(result.shape(), &[5, 4]);
///
/// // A      (2d tensor):  5 x 4
/// // B      (1d tensor):      4
/// // --------------------------
/// // Result (2d tensor):  5 x 4
/// let a = rt::arange((20, &device)).into_shape([5, 4]);
/// let b = rt::arange((4, &device)).into_shape([4]);
/// let result = &a + &b;
/// assert_eq!(result.shape(), &[5, 4]);
///
/// // A      (3d tensor):  15 x 3 x 5
/// // B      (3d tensor):  15 x 1 x 5
/// // -------------------------------
/// // Result (3d tensor):  15 x 3 x 5
/// let a = rt::arange((225, &device)).into_shape([15, 3, 5]);
/// let b = rt::arange((75, &device)).into_shape([15, 1, 5]);
/// let result = &a + &b;
/// assert_eq!(result.shape(), &[15, 3, 5]);
///
/// // A      (3d tensor):  15 x 3 x 5
/// // B      (2d tensor):       3 x 5
/// // -------------------------------
/// // Result (3d tensor):  15 x 3 x 5
/// let a = rt::arange((225, &device)).into_shape([15, 3, 5]);
/// let b = rt::arange((15, &device)).into_shape([3, 5]);
/// let result = &a + &b;
/// assert_eq!(result.shape(), &[15, 3, 5]);
///
/// // A      (3d tensor):  15 x 3 x 5
/// // B      (2d tensor):       3 x 1
/// // -------------------------------
/// // Result (3d tensor):  15 x 3 x 5
/// let a = rt::arange((225, &device)).into_shape([15, 3, 5]);
/// let b = rt::arange((3, &device)).into_shape([3, 1]);
/// let result = &a + &b;
/// assert_eq!(result.shape(), &[15, 3, 5]);
/// ```
///
/// ## Broadcasting behavior (in col-major)
///
/// This example does not directly call `to_broadcast`, but demonstrates the broadcasting
/// behavior.
///
/// ```rust
/// use rstsr::prelude::*;
/// let mut device = DeviceCpu::default();
/// device.set_default_order(ColMajor);
///
/// // A      (4d tensor):  1 x 6 x 1 x 8
/// // B      (3d tensor):  5 x 1 x 7
/// // ----------------------------------
/// // Result (4d tensor):  5 x 6 x 7 x 8
/// let a = rt::arange((48, &device)).into_shape([1, 6, 1, 8]);
/// let b = rt::arange((35, &device)).into_shape([5, 1, 7]);
/// let result = &a + &b;
/// assert_eq!(result.shape(), &[5, 6, 7, 8]);
///
/// // A      (2d tensor):  4 x 5
/// // B      (1d tensor):  1
/// // --------------------------
/// // Result (2d tensor):  4 x 5
/// let a = rt::arange((20, &device)).into_shape([4, 5]);
/// let b = rt::arange((1, &device)).into_shape([1]);
/// let result = &a + &b;
/// assert_eq!(result.shape(), &[4, 5]);
///
/// // A      (2d tensor):  4 x 5
/// // B      (1d tensor):  4
/// // --------------------------
/// // Result (2d tensor):  4 x 5
/// let a = rt::arange((20, &device)).into_shape([4, 5]);
/// let b = rt::arange((4, &device)).into_shape([4]);
/// let result = &a + &b;
/// assert_eq!(result.shape(), &[4, 5]);
///
/// // A      (3d tensor):  5 x 3 x 15
/// // B      (3d tensor):  5 x 1 x 15
/// // -------------------------------
/// // Result (3d tensor):  5 x 3 x 15
/// let a = rt::arange((225, &device)).into_shape([5, 3, 15]);
/// let b = rt::arange((75, &device)).into_shape([5, 1, 15]);
/// let result = &a + &b;
/// assert_eq!(result.shape(), &[5, 3, 15]);
///
/// // A      (3d tensor):  5 x 3 x 15
/// // B      (2d tensor):  5 x 3
/// // -------------------------------
/// // Result (3d tensor):  5 x 3 x 15
/// let a = rt::arange((225, &device)).into_shape([5, 3, 15]);
/// let b = rt::arange((15, &device)).into_shape([5, 3]);
/// let result = &a + &b;
/// assert_eq!(result.shape(), &[5, 3, 15]);
///
/// // A      (3d tensor):  5 x 3 x 15
/// // B      (2d tensor):  1 x 3
/// // -------------------------------
/// // Result (3d tensor):  5 x 3 x 15
/// let a = rt::arange((225, &device)).into_shape([5, 3, 15]);
/// let b = rt::arange((3, &device)).into_shape([1, 3]);
/// let result = &a + &b;
/// assert_eq!(result.shape(), &[5, 3, 15]);
/// ```
///
/// # Notes of API accordance
///
/// - Array-API: `broadcast_to(x, /, shape)` ([`broadcast_to`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.broadcast_to.html))
/// - NumPy: `numpy.broadcast_to(array, shape, subok=False)` ([`numpy.broadcast_to`](https://numpy.org/doc/stable/reference/generated/numpy.broadcast_to.html))
/// - ndarray: [`ArrayBase::broadcast`](https://docs.rs/ndarray/latest/ndarray/struct.ArrayBase.html#method.broadcast)
/// - RSTSR: `rt::to_broadcast(tensor, shape)` or `rt::broadcast_to(tensor, shape)`
///
/// # Panics
///
/// Panics if the tensor cannot be broadcast to the given shape.
///
/// For a fallible version, use [`to_broadcast_f`].
///
/// # See also
///
/// ## Detailed broadcasting rules
///
/// - Array-API: [Broadcasting rules](https://data-apis.org/array-api/latest/API_specification/broadcasting.html)
/// - NumPy: [Broadcasting](https://numpy.org/doc/stable/user/basics.broadcasting.html)
///
/// ## Related functions in RSTSR
///
/// - [`broadcast_arrays`]: Broadcasts any number of arrays against each other.
///
/// ## Variants of this function
///
/// - [`to_broadcast`] / [`to_broadcast_f`]: Non-consuming version.
/// - [`into_broadcast`] / [`into_broadcast_f`]: Consuming version.
/// - [`broadcast_to`]: Alias for `to_broadcast` (name from the Python Array API standard).
/// - Associated methods on [`TensorAny`]:
///
///   - [`TensorAny::to_broadcast`] / [`TensorAny::to_broadcast_f`]
///   - [`TensorAny::into_broadcast`] / [`TensorAny::into_broadcast_f`]
///   - [`TensorAny::broadcast_to`]
pub fn to_broadcast<R, T, B, D, D2>(tensor: &TensorAny<R, T, B, D>, shape: D2) -> TensorView<'_, T, B, D2>
where
    D: DimAPI + DimMaxAPI<D2, Max = D2>,
    D2: DimAPI,
    R: DataAPI<Data = B::Raw>,
    B: DeviceAPI<T>,
{
    into_broadcast_f(tensor.view(), shape).rstsr_unwrap()
}

/// Broadcasts an array to a specified shape.
///
/// See also [`to_broadcast`].
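///
/// # Examples
///
/// A minimal sketch of handling the returned `Result` (not from the original docs, and assuming
/// this fallible variant is re-exported under `rt::`); the target shape `(2, 4)` is chosen so
/// that broadcasting from `(3,)` fails in row-major order:
///
/// ```rust
/// # use rstsr::prelude::*;
/// # let mut device = DeviceCpu::default();
/// # device.set_default_order(RowMajor);
/// let a = rt::tensor_from_nested!([1, 2, 3], &device);
/// // (3,) cannot broadcast to (2, 4): trailing dims 3 vs 4
/// assert!(rt::to_broadcast_f(&a, vec![2, 4]).is_err());
/// ```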
pub fn to_broadcast_f<R, T, B, D, D2>(tensor: &TensorAny<R, T, B, D>, shape: D2) -> Result<TensorView<'_, T, B, D2>>
where
    D: DimAPI + DimMaxAPI<D2, Max = D2>,
    D2: DimAPI,
    R: DataAPI<Data = B::Raw>,
    B: DeviceAPI<T>,
{
    into_broadcast_f(tensor.view(), shape)
}

/// Broadcasts an array to a specified shape.
///
/// See also [`to_broadcast`].
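///
/// # Examples
///
/// A minimal sketch of the consuming variant (not from the original docs); it mirrors the
/// row-major `to_broadcast` example above, but takes ownership of `a` and returns a tensor with
/// the same ownership instead of a view:
///
/// ```rust
/// # use rstsr::prelude::*;
/// # let mut device = DeviceCpu::default();
/// # device.set_default_order(RowMajor);
/// let a = rt::tensor_from_nested!([1, 2, 3], &device);
/// let b = a.into_broadcast(vec![2, 3]); // `a` is consumed here
/// assert_eq!(b.shape(), &[2, 3]);
/// ```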
pub fn into_broadcast<R, T, B, D, D2>(tensor: TensorAny<R, T, B, D>, shape: D2) -> TensorAny<R, T, B, D2>
where
    R: DataAPI<Data = B::Raw>,
    B: DeviceAPI<T>,
    D: DimAPI + DimMaxAPI<D2, Max = D2>,
    D2: DimAPI,
{
    into_broadcast_f(tensor, shape).rstsr_unwrap()
}

pub use into_broadcast as broadcast_into;
pub use into_broadcast_f as broadcast_into_f;
pub use to_broadcast as broadcast_to;
pub use to_broadcast_f as broadcast_to_f;

impl<R, T, B, D> TensorAny<R, T, B, D>
where
    R: DataAPI<Data = B::Raw>,
    B: DeviceAPI<T>,
    D: DimAPI,
{
    /// Broadcasts an array to a specified shape.
    ///
    /// See also [`to_broadcast`].
    pub fn to_broadcast<D2>(&self, shape: D2) -> TensorView<'_, T, B, D2>
    where
        D2: DimAPI,
        D: DimMaxAPI<D2, Max = D2>,
    {
        to_broadcast(self, shape)
    }

    /// Broadcasts an array to a specified shape.
    ///
    /// See also [`to_broadcast`].
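    ///
    /// # Examples
    ///
    /// A minimal sketch of the method form (not from the original docs), following the row-major
    /// `to_broadcast` example above:
    ///
    /// ```rust
    /// # use rstsr::prelude::*;
    /// # let mut device = DeviceCpu::default();
    /// # device.set_default_order(RowMajor);
    /// let a = rt::tensor_from_nested!([1, 2, 3], &device);
    /// let v = a.broadcast_to(vec![2, 3]); // read-only view with shape (2, 3)
    /// assert_eq!(v.shape(), &[2, 3]);
    /// ```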
    pub fn broadcast_to<D2>(&self, shape: D2) -> TensorView<'_, T, B, D2>
    where
        D2: DimAPI,
        D: DimMaxAPI<D2, Max = D2>,
    {
        broadcast_to(self, shape)
    }

    /// Broadcasts an array to a specified shape.
    ///
    /// See also [`to_broadcast`].
    pub fn to_broadcast_f<D2>(&self, shape: D2) -> Result<TensorView<'_, T, B, D2>>
    where
        D2: DimAPI,
        D: DimMaxAPI<D2, Max = D2>,
    {
        to_broadcast_f(self, shape)
    }

    /// Broadcasts an array to a specified shape.
    ///
    /// See also [`to_broadcast`].
    pub fn into_broadcast<D2>(self, shape: D2) -> TensorAny<R, T, B, D2>
    where
        D2: DimAPI,
        D: DimMaxAPI<D2, Max = D2>,
    {
        into_broadcast(self, shape)
    }

    /// Broadcasts an array to a specified shape.
    ///
    /// See also [`to_broadcast`].
    pub fn into_broadcast_f<D2>(self, shape: D2) -> Result<TensorAny<R, T, B, D2>>
    where
        D2: DimAPI,
        D: DimMaxAPI<D2, Max = D2>,
    {
        into_broadcast_f(self, shape)
    }
}

/* #endregion */