rstsr_core/tensor/manuplication/broadcast.rs
1use crate::prelude_dev::*;
2
3/* #region broadcast_shapes */
4
5/// Broadcasts shapes against each other and returns the resulting shape.
6///
7/// This function computes the shape that results from broadcasting the given
8/// shapes against one another according to the standard broadcasting rules.
9/// It does not require actual tensor data, only the shapes.
10///
11/// <div class="warning">
12///
13/// **Row/Column Major Notice**
14///
15/// This function behaves differently on default orders ([`RowMajor`] and [`ColMajor`]) of device.
16///
17/// </div>
18///
19/// # Parameters
20///
21/// - `shapes`: A slice of shapes, where each shape is [`IxD`] (`Vec<usize>`).
22/// - `order`: The memory order ([`RowMajor`] or [`ColMajor`]) that determines how broadcasting
23/// rules are applied.
24///
25/// # Returns
26///
27/// - `IxD`: The resulting shape after broadcasting all input shapes together.
28///
29/// # Examples
30///
31/// Broadcast two shapes in row-major order:
32///
33/// ```rust
34/// # use rstsr::prelude::*;
35/// // A (4d array): 8 x 1 x 6 x 1
36/// // B (3d array): 7 x 1 x 5
37/// // ---------------------------------
38/// // Result (4d array): 8 x 7 x 6 x 5
39/// let shape1 = vec![8, 1, 6, 1];
40/// let shape2 = vec![7, 1, 5];
41/// let result = rt::broadcast_shapes(&[shape1, shape2], RowMajor);
42/// println!("{:?}", result);
43/// // [8, 7, 6, 5]
44/// assert_eq!(result, vec![8, 7, 6, 5]);
45/// ```
46///
/// In col-major order, the broadcasting is applied from the left instead of the right:
48///
49/// ```rust
50/// # use rstsr::prelude::*;
51/// // A (4d array): 1 x 6 x 1 x 8
52/// // B (3d array): 5 x 1 x 7
53/// // ---------------------------------
54/// // Result (4d array): 5 x 6 x 7 x 8
55/// let shape1 = vec![1, 6, 1, 8];
56/// let shape2 = vec![5, 1, 7];
57/// let result = rt::broadcast_shapes(&[shape1, shape2], ColMajor);
58/// println!("{:?}", result);
59/// // [5, 6, 7, 8]
60/// assert_eq!(result, vec![5, 6, 7, 8]);
61/// ```
62///
63/// Broadcast multiple shapes:
64///
65/// ```rust
66/// # use rstsr::prelude::*;
67/// // Three shapes: (1,), (3, 1), (3, 4) -> (3, 4)
68/// let shapes = vec![vec![1], vec![3, 1], vec![3, 4]];
69/// let result = rt::broadcast_shapes(&shapes, RowMajor);
70/// println!("{:?}", result);
71/// // [3, 4]
72/// assert_eq!(result, vec![3, 4]);
73/// ```
74///
75/// # Notes of API accordance
76///
77/// - Array-API: `broadcast_shapes(*shapes)` ([`broadcast_shapes`](https://data-apis.org/array-api/2024.12/API_specification/generated/array_api.broadcast_shapes.html))
78/// - NumPy: `numpy.broadcast_shapes(*shapes)` ([`numpy.broadcast_shapes`](https://numpy.org/doc/stable/reference/generated/numpy.broadcast_shapes.html))
79/// - RSTSR: `rt::broadcast_shapes(shapes, order)`
80///
81/// # Panics
82///
83/// Panics if the shapes cannot be broadcast together.
84///
85/// For a fallible version, use [`broadcast_shapes_f`].
86///
87/// # See also
88///
89/// ## Variants of this function
90///
91/// - [`broadcast_shapes_f`]: Fallible version, actual implementation.
92///
93/// ## Related functions in RSTSR
94///
95/// - [`broadcast_arrays`]: Broadcasts any number of arrays against each other.
96/// - [`broadcast_to`]: Broadcasts an array to a specified shape.
97pub fn broadcast_shapes(shapes: &[IxD], order: FlagOrder) -> IxD {
98 broadcast_shapes_f(shapes, order).rstsr_unwrap()
99}
100
101/// Broadcasts shapes against each other and returns the resulting shape.
102///
103/// See also [`broadcast_shapes`].
104pub fn broadcast_shapes_f(shapes: &[IxD], order: FlagOrder) -> Result<IxD> {
105 // fast return if empty or single shape
106 if shapes.is_empty() {
107 return Ok(vec![]);
108 }
109 if shapes.len() == 1 {
110 return Ok(shapes[0].clone());
111 }
112
113 // iteratively broadcast all shapes
114 let mut result = shapes[0].clone();
115 for shape in shapes.iter().skip(1) {
116 let (new_result, _, _) = broadcast_shape(shape, &result, order)?;
117 result = new_result;
118 }
119 Ok(result)
120}
121
122/* #endregion */
123
124/* #region broadcast_arrays */
125
126/// Broadcasts any number of arrays against each other.
127///
128/// <div class="warning">
129///
130/// **Row/Column Major Notice**
131///
132/// This function behaves differently on default orders ([`RowMajor`] and [`ColMajor`]) of device.
133///
134/// </div>
135///
136/// # Parameters
137///
138/// - `tensors`: [`Vec<TensorAny<R, T, B, IxD>>`](TensorAny)
139///
140/// - The tensors to be broadcasted.
141/// - All tensors must be on the same device, and share the same ownerships.
142/// - This function takes ownership of the input tensors. If you want to obtain broadcasted views,
143/// you need to create a new vector of views first.
144/// - This function only accepts dynamic shape tensors ([`IxD`]).
145///
146/// # Returns
147///
148/// - [`Vec<TensorAny<R, T, B, IxD>>`](TensorAny)
149///
150/// - A vector of broadcasted tensors. Each tensor has the same shape after broadcasting.
151/// - The ownership of the underlying data is moved from the input tensors to the output tensors.
152/// - The tensors are typically not contiguous (with zero strides at the broadcasted axes).
153/// Writing values to broadcasted tensors is dangerous, but RSTSR will generally not panic on
/// this behavior. Call [`to_contig`] afterwards if you require owned contiguous tensors.
155///
156/// # Examples
157///
158/// The following example demonstrates how to use `broadcast_arrays` to broadcast two tensors:
159///
160/// ```rust
161/// # use rstsr::prelude::*;
162/// # let mut device = DeviceCpu::default();
163/// # device.set_default_order(RowMajor);
164/// let a = rt::asarray((vec![1, 2, 3], &device)).into_shape([3]);
165/// println!("{a}");
166/// // [ 1 2 3]
167/// let b = rt::asarray((vec![4, 5], &device)).into_shape([2, 1]);
168/// println!("{b}");
169/// // [[ 4]
170/// // [ 5]]
171///
172/// let result = rt::broadcast_arrays(vec![a, b]);
173/// println!("broadcasted a:\n{:}", result[0]);
174/// // [[ 1 2 3]
175/// // [ 1 2 3]]
176/// println!("broadcasted b:\n{:}", result[1]);
177/// // [[ 4 4 4]
178/// // [ 5 5 5]]
179/// ```
180///
181/// Please note that the above code only works in [RowMajor].
182///
/// For [ColMajor] order, the broadcasting will fail: the broadcasting rules are applied
/// differently, so the shapes become incompatible (in col-major, broadcast comparison starts
/// from the left instead of from the right as in row-major):
186///
187/// ```rust,should_panic
188/// # use rstsr::prelude::*;
189/// # let mut device = DeviceCpu::default();
190/// device.set_default_order(ColMajor);
191/// let a = rt::asarray((vec![1, 2, 3], &device)).into_shape([3]);
192/// let b = rt::asarray((vec![4, 5], &device)).into_shape([2, 1]);
193/// let result = rt::broadcast_arrays(vec![a, b]);
194/// ```
195///
/// To make the [ColMajor] case work, apply the following changes:
197///
198/// ```rust
199/// # use rstsr::prelude::*;
200/// # let mut device = DeviceCpu::default();
201/// device.set_default_order(ColMajor);
202/// let a = rt::asarray((vec![1, 2, 3], &device)).into_shape([1, 3]);
203/// let b = rt::asarray((vec![4, 5], &device)).into_shape([2, 1]);
/// let result = rt::broadcast_arrays(vec![a, b]);
/// assert_eq!(result[0].shape(), &[2, 3]);
205/// ```
206///
207/// # Notes of API accordance
208///
209/// - Array-API: `broadcast_arrays(*arrays)` ([`broadcast_arrays`](https://data-apis.org/array-api/2024.12/API_specification/generated/array_api.broadcast_arrays.html))
210/// - NumPy: `numpy.broadcast_arrays(*args, subok=False)` ([`numpy.broadcast_arrays`](https://numpy.org/doc/stable/reference/generated/numpy.broadcast_arrays.html))
211/// - RSTSR: `rt::broadcast_arrays(tensors)`
212///
213/// # Panics
214///
215/// Panics if the tensors cannot be broadcast together, or if tensors are on different devices.
216///
217/// For a fallible version, use [`broadcast_arrays_f`].
218///
219/// # See also
220///
221/// ## Related functions in RSTSR
222///
223/// - [`to_broadcast`]: Broadcasts a single array to a specified shape.
224///
225/// ## Variants of this function
226///
227/// - [`broadcast_arrays_f`]: Fallible version, actual implementation.
228pub fn broadcast_arrays<R, T, B>(tensors: Vec<TensorAny<R, T, B, IxD>>) -> Vec<TensorAny<R, T, B, IxD>>
229where
230 R: DataAPI<Data = B::Raw>,
231 B: DeviceAPI<T>,
232{
233 broadcast_arrays_f(tensors).rstsr_unwrap()
234}
235
236/// Broadcasts any number of arrays against each other.
237///
238/// See also [`broadcast_arrays`].
239pub fn broadcast_arrays_f<R, T, B>(tensors: Vec<TensorAny<R, T, B, IxD>>) -> Result<Vec<TensorAny<R, T, B, IxD>>>
240where
241 R: DataAPI<Data = B::Raw>,
242 B: DeviceAPI<T>,
243{
244 // fast return if there is only zero/one tensor
245 if tensors.len() <= 1 {
246 return Ok(tensors);
247 }
248 let device_b = tensors[0].device().clone();
249 let default_order = device_b.default_order();
250 let mut shape_b = tensors[0].shape().clone();
251 for tensor in tensors.iter().skip(1) {
252 rstsr_assert!(device_b.same_device(tensor.device()), DeviceMismatch)?;
253 let shape = tensor.shape();
254 let (shape, _, _) = broadcast_shape(shape, &shape_b, default_order)?;
255 shape_b = shape;
256 }
257 let mut tensors_new = Vec::with_capacity(tensors.len());
258 for tensor in tensors {
259 let tensor = into_broadcast_f(tensor, shape_b.clone())?;
260 tensors_new.push(tensor);
261 }
262 return Ok(tensors_new);
263}
264
265/* #endregion */
266
267/* #region broadcast_to */
268
269/// Broadcasts an array to a specified shape.
270///
271/// See also [`to_broadcast`].
272pub fn into_broadcast_f<R, T, B, D, D2>(tensor: TensorAny<R, T, B, D>, shape: D2) -> Result<TensorAny<R, T, B, D2>>
273where
274 R: DataAPI<Data = B::Raw>,
275 B: DeviceAPI<T>,
276 D: DimAPI + DimMaxAPI<D2, Max = D2>,
277 D2: DimAPI,
278{
279 let shape1 = tensor.shape();
280 let shape2 = &shape;
281 let default_order = tensor.device().default_order();
282 let (shape, tp1, _) = broadcast_shape(shape1, shape2, default_order)?;
283 // For broadcast_to, the result shape must match the target shape exactly
284 // This ensures we cannot broadcast from higher dimensions to lower dimensions
285 rstsr_assert_eq!(
286 shape.as_ref(),
287 shape2.as_ref(),
288 InvalidLayout,
289 "Cannot broadcast to a shape with fewer dimensions or incompatible size."
290 )?;
291 let (storage, layout) = tensor.into_raw_parts();
292 let layout = update_layout_by_shape(&layout, &shape, &tp1, default_order)?;
293 unsafe { Ok(TensorBase::new_unchecked(storage, layout)) }
294}
295
296/// Broadcasts an array to a specified shape.
297///
298/// <div class="warning">
299///
300/// **Row/Column Major Notice**
301///
302/// This function behaves differently on default orders ([`RowMajor`] and [`ColMajor`]) of device.
303///
304/// </div>
305///
306/// # Parameters
307///
308/// - `tensor`: [`&TensorAny<R, T, B, D>`](TensorAny)
309///
310/// - The input tensor to be broadcasted.
/// - Note on variant [`into_broadcast`]: it takes ownership of the input tensor
/// ([`TensorAny<R, T, B, D>`]) and changes only its layout, not the underlying data.
313///
314/// - `shape`: impl [`DimAPI`]
315///
316/// - The shape of the desired output tensor after broadcasting.
/// - Please note [`IxD`] (`Vec<usize>`) and [`Ix<N>`] (`[usize; N]`) behave differently here.
318/// [`IxD`] will give dynamic shape tensor, while [`Ix<N>`] will give static shape tensor.
319/// Unless you are handling static-shape tensors, always use [`IxD`] (`Vec<usize>`) to specify
320/// shape.
321///
322/// # Returns
323///
324/// - [`TensorView<'_, T, B, D2>`]
325///
326/// - A readonly view on the original tensor with the given shape. It is typically not contiguous
327/// (perform [`to_contig`] afterwards if you require contiguous owned tensors).
328/// - Furthermore, more than one element of a broadcasted tensor may refer to a single memory
329/// location (zero strides at the broadcasted axes). Writing values to broadcasted tensors is
330/// dangerous, but RSTSR will generally not panic on this behavior.
331/// - If you want to convert the tensor itself (taking the ownership instead of returning view),
332/// use [`into_broadcast`] instead.
333///
334/// # Examples
335///
336/// The following example demonstrates how to use `to_broadcast` to broadcast a 1-D tensor
337/// (3-element vector) to a 2-D tensor (2x3 matrix) by repeating the original data along a new axis.
338///
339/// ```rust
340/// # use rstsr::prelude::*;
341/// # let mut device = DeviceCpu::default();
342/// # device.set_default_order(RowMajor);
343/// // broadcast (3, ) -> (2, 3) in row-major:
344/// let a = rt::tensor_from_nested!([1, 2, 3], &device);
345/// let result = a.to_broadcast(vec![2, 3]);
346/// println!("{result}");
347/// // [[ 1 2 3]
348/// // [ 1 2 3]]
349/// ```
350///
/// Please note the above example only works in [`RowMajor`] order. In [`ColMajor`] order, shapes
/// are aligned from the left instead of the right, so the same call panics:
353///
354/// ```rust,should_panic
355/// # use rstsr::prelude::*;
356/// # let mut device = DeviceCpu::default();
357/// device.set_default_order(ColMajor);
358/// let a = rt::tensor_from_nested!([1, 2, 3], &device);
359/// // in col-major, broadcast (3, ) -> (2, 3) will panic:
360/// let result = a.to_broadcast(vec![2, 3]);
361/// ```
362///
363/// The correct shape for col-major is `(3, 2)` instead:
364///
365/// ```rust
366/// # use rstsr::prelude::*;
367/// # let mut device = DeviceCpu::default();
368/// device.set_default_order(ColMajor);
369/// let a = rt::tensor_from_nested!([1, 2, 3], &device);
370/// // broadcast (3, ) -> (3, 2) in col-major:
371/// let result = a.to_broadcast(vec![3, 2]);
372/// println!("{result}");
373/// // [[ 1 1]
374/// // [ 2 2]
375/// // [ 3 3]]
376/// ```
377///
378/// # Tips on common compilation errors
379///
380/// You may encounter compilation errors related to shape types like:
381///
382/// ```compile_fail
383/// # use rstsr::prelude::*;
384/// # let mut device = DeviceCpu::default();
385/// # device.set_default_order(RowMajor);
386/// let input_array: Tensor<i32, _> = rt::asarray((0, &device));
387/// let result = rt::broadcast_to(&input_array, [3]);
388/// ```
389///
390/// ```text
391/// 33 | let result = rt::broadcast_to(&input_array, [3]);
392/// | ---------------- ^^^^^^^^^^^^ expected `[usize; 1]`, found `Vec<usize>`
393/// | |
394/// | required by a bound introduced by this call
395/// |
396/// = note: expected array `[usize; 1]`
397/// found struct `std::vec::Vec<usize>`
398/// note: required by a bound in `rstsr_core::prelude_dev::to_broadcast`
399/// --> rstsr-core/src/tensor/manuplication/broadcast.rs:427:31
400/// |
401/// 425 | pub fn to_broadcast<R, T, B, D, D2>(tensor: &TensorAny<R, T, B, D>, shape: D2) -> TensorView<'_, T, B, D2>
402/// | ------------ required by a bound in this function
403/// 426 | where
404/// 427 | D: DimAPI + DimMaxAPI<D2, Max = D2>,
405/// | ^^^^^^^^ required by this bound in `to_broadcast`
406/// ```
407///
/// To resolve this, we suggest using [`IxD`] (`Vec<usize>`) to specify the shape instead of
/// [`Ix<N>`] (`[usize; N]`), unless you are an advanced API user handling static-shape
/// tensors.
411///
412/// ```rust
413/// # use rstsr::prelude::*;
414/// # let mut device = DeviceCpu::default();
415/// # device.set_default_order(RowMajor);
416/// let input_array: Tensor<i32, _> = rt::asarray((0, &device));
417/// let result = rt::broadcast_to(&input_array, vec![3]);
418/// ```
419///
420/// # Elaborated examples
421///
422/// ## Broadcasting behavior (in row-major)
423///
424/// This example does not directly call this function `to_broadcast`, but demonstrates the
425/// broadcasting behavior.
426///
427/// ```rust
428/// use rstsr::prelude::*;
429/// let mut device = DeviceCpu::default();
430/// device.set_default_order(RowMajor);
431///
432/// // A (4d tensor): 8 x 1 x 6 x 1
433/// // B (3d tensor): 7 x 1 x 5
434/// // ----------------------------------
435/// // Result (4d tensor): 8 x 7 x 6 x 5
436/// let a = rt::arange((48, &device)).into_shape([8, 1, 6, 1]);
437/// let b = rt::arange((35, &device)).into_shape([7, 1, 5]);
438/// let result = &a + &b;
439/// assert_eq!(result.shape(), &[8, 7, 6, 5]);
440///
441/// // A (2d tensor): 5 x 4
442/// // B (1d tensor): 1
443/// // --------------------------
444/// // Result (2d tensor): 5 x 4
445/// let a = rt::arange((20, &device)).into_shape([5, 4]);
446/// let b = rt::arange((1, &device)).into_shape([1]);
447/// let result = &a + &b;
448/// assert_eq!(result.shape(), &[5, 4]);
449///
450/// // A (2d tensor): 5 x 4
451/// // B (1d tensor): 4
452/// // --------------------------
453/// // Result (2d tensor): 5 x 4
454/// let a = rt::arange((20, &device)).into_shape([5, 4]);
455/// let b = rt::arange((4, &device)).into_shape([4]);
456/// let result = &a + &b;
457/// assert_eq!(result.shape(), &[5, 4]);
458///
459/// // A (3d tensor): 15 x 3 x 5
460/// // B (3d tensor): 15 x 1 x 5
461/// // -------------------------------
462/// // Result (3d tensor): 15 x 3 x 5
463/// let a = rt::arange((225, &device)).into_shape([15, 3, 5]);
464/// let b = rt::arange((75, &device)).into_shape([15, 1, 5]);
465/// let result = &a + &b;
466/// assert_eq!(result.shape(), &[15, 3, 5]);
467///
468/// // A (3d tensor): 15 x 3 x 5
469/// // B (2d tensor): 3 x 5
470/// // -------------------------------
471/// // Result (3d tensor): 15 x 3 x 5
472/// let a = rt::arange((225, &device)).into_shape([15, 3, 5]);
473/// let b = rt::arange((15, &device)).into_shape([3, 5]);
474/// let result = &a + &b;
475/// assert_eq!(result.shape(), &[15, 3, 5]);
476///
477/// // A (3d tensor): 15 x 3 x 5
478/// // B (2d tensor): 3 x 1
479/// // -------------------------------
480/// // Result (3d tensor): 15 x 3 x 5
481/// let a = rt::arange((225, &device)).into_shape([15, 3, 5]);
482/// let b = rt::arange((3, &device)).into_shape([3, 1]);
483/// let result = &a + &b;
484/// assert_eq!(result.shape(), &[15, 3, 5]);
485/// ```
486///
487/// ## Broadcasting behavior (in col-major)
488///
489/// This example does not directly call this function `to_broadcast`, but demonstrates the
490/// broadcasting behavior.
491///
492/// ```rust
493/// use rstsr::prelude::*;
494/// let mut device = DeviceCpu::default();
495/// device.set_default_order(ColMajor);
496///
497/// // A (4d tensor): 1 x 6 x 1 x 8
498/// // B (3d tensor): 5 x 1 x 7
499/// // ----------------------------------
500/// // Result (4d tensor): 5 x 6 x 7 x 8
501/// let a = rt::arange((48, &device)).into_shape([1, 6, 1, 8]);
502/// let b = rt::arange((35, &device)).into_shape([5, 1, 7]);
503/// let result = &a + &b;
504/// assert_eq!(result.shape(), &[5, 6, 7, 8]);
505///
506/// // A (2d tensor): 4 x 5
507/// // B (1d tensor): 1
508/// // --------------------------
509/// // Result (2d tensor): 4 x 5
510/// let a = rt::arange((20, &device)).into_shape([4, 5]);
511/// let b = rt::arange((1, &device)).into_shape([1]);
512/// let result = &a + &b;
513/// assert_eq!(result.shape(), &[4, 5]);
514///
515/// // A (2d tensor): 4 x 5
516/// // B (1d tensor): 4
517/// // --------------------------
518/// // Result (2d tensor): 4 x 5
519/// let a = rt::arange((20, &device)).into_shape([4, 5]);
520/// let b = rt::arange((4, &device)).into_shape([4]);
521/// let result = &a + &b;
522/// assert_eq!(result.shape(), &[4, 5]);
523///
524/// // A (3d tensor): 5 x 3 x 15
525/// // B (3d tensor): 5 x 1 x 15
526/// // -------------------------------
527/// // Result (3d tensor): 5 x 3 x 15
528/// let a = rt::arange((225, &device)).into_shape([5, 3, 15]);
529/// let b = rt::arange((75, &device)).into_shape([5, 1, 15]);
530/// let result = &a + &b;
531/// assert_eq!(result.shape(), &[5, 3, 15]);
532///
533/// // A (3d tensor): 5 x 3 x 15
534/// // B (2d tensor): 5 x 3
535/// // -------------------------------
536/// // Result (3d tensor): 5 x 3 x 15
537/// let a = rt::arange((225, &device)).into_shape([5, 3, 15]);
538/// let b = rt::arange((15, &device)).into_shape([5, 3]);
539/// let result = &a + &b;
540/// assert_eq!(result.shape(), &[5, 3, 15]);
541///
542/// // A (3d tensor): 5 x 3 x 15
543/// // B (2d tensor): 1 x 3
544/// // -------------------------------
545/// // Result (3d tensor): 5 x 3 x 15
546/// let a = rt::arange((225, &device)).into_shape([5, 3, 15]);
547/// let b = rt::arange((3, &device)).into_shape([1, 3]);
548/// let result = &a + &b;
549/// assert_eq!(result.shape(), &[5, 3, 15]);
550/// ```
551///
552/// # Notes of API accordance
553///
554/// - Array-API: `broadcast_to(x, /, shape)` ([`broadcast_to`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.broadcast_to.html))
555/// - NumPy: `numpy.broadcast_to(array, shape, subok=False)` ([`numpy.broadcast_to`](https://numpy.org/doc/stable/reference/generated/numpy.broadcast_to.html))
556/// - ndarray: [`ArrayBase::broadcast`](https://docs.rs/ndarray/latest/ndarray/struct.ArrayBase.html#method.broadcast)
557/// - RSTSR: `rt::to_broadcast(tensor, shape)` or `rt::broadcast_to(tensor, shape)`
558///
559/// # Panics
560///
561/// Panics if the tensor cannot be broadcast to the given shape.
562///
563/// For a fallible version, use [`to_broadcast_f`].
564///
565/// # See also
566///
567/// ## Detailed broadcasting rules
568///
569/// - Array-API: [Broadcasting rules](https://data-apis.org/array-api/latest/API_specification/broadcasting.html)
570/// - NumPy: [Broadcasting](https://numpy.org/doc/stable/user/basics.broadcasting.html)
571///
572/// ## Related functions in RSTSR
573///
574/// - [`broadcast_arrays`]: Broadcasts any number of arrays against each other.
575///
576/// ## Variants of this function
577///
578/// - [`to_broadcast`] / [`to_broadcast_f`]: Non-consuming version.
579/// - [`into_broadcast`] / [`into_broadcast_f`]: Consuming version.
580/// - [`broadcast_to`]: Alias for `to_broadcast` (name of Python Array API standard).
581/// - Associated methods on [`TensorAny`]:
582///
583/// - [`TensorAny::to_broadcast`] / [`TensorAny::to_broadcast_f`]
584/// - [`TensorAny::into_broadcast`] / [`TensorAny::into_broadcast_f`]
585/// - [`TensorAny::broadcast_to`]
586pub fn to_broadcast<R, T, B, D, D2>(tensor: &TensorAny<R, T, B, D>, shape: D2) -> TensorView<'_, T, B, D2>
587where
588 D: DimAPI + DimMaxAPI<D2, Max = D2>,
589 D2: DimAPI,
590 R: DataAPI<Data = B::Raw>,
591 B: DeviceAPI<T>,
592{
593 into_broadcast_f(tensor.view(), shape).rstsr_unwrap()
594}
595
596/// Broadcasts an array to a specified shape.
597///
598/// See also [`to_broadcast`].
599pub fn to_broadcast_f<R, T, B, D, D2>(tensor: &TensorAny<R, T, B, D>, shape: D2) -> Result<TensorView<'_, T, B, D2>>
600where
601 D: DimAPI + DimMaxAPI<D2, Max = D2>,
602 D2: DimAPI,
603 R: DataAPI<Data = B::Raw>,
604 B: DeviceAPI<T>,
605{
606 into_broadcast_f(tensor.view(), shape)
607}
608
609/// Broadcasts an array to a specified shape.
610///
611/// See also [`to_broadcast`].
612pub fn into_broadcast<R, T, B, D, D2>(tensor: TensorAny<R, T, B, D>, shape: D2) -> TensorAny<R, T, B, D2>
613where
614 R: DataAPI<Data = B::Raw>,
615 B: DeviceAPI<T>,
616 D: DimAPI + DimMaxAPI<D2, Max = D2>,
617 D2: DimAPI,
618{
619 into_broadcast_f(tensor, shape).rstsr_unwrap()
620}
621
622pub use into_broadcast as broadcast_into;
623pub use into_broadcast_f as broadcast_into_f;
624pub use to_broadcast as broadcast_to;
625pub use to_broadcast_f as broadcast_to_f;
626
627impl<R, T, B, D> TensorAny<R, T, B, D>
628where
629 R: DataAPI<Data = B::Raw>,
630 B: DeviceAPI<T>,
631 D: DimAPI,
632{
633 /// Broadcasts an array to a specified shape.
634 ///
635 /// See also [`to_broadcast`].
636 pub fn to_broadcast<D2>(&self, shape: D2) -> TensorView<'_, T, B, D2>
637 where
638 D2: DimAPI,
639 D: DimMaxAPI<D2, Max = D2>,
640 {
641 to_broadcast(self, shape)
642 }
643
644 /// Broadcasts an array to a specified shape.
645 ///
646 /// See also [`to_broadcast`].
647 pub fn broadcast_to<D2>(&self, shape: D2) -> TensorView<'_, T, B, D2>
648 where
649 D2: DimAPI,
650 D: DimMaxAPI<D2, Max = D2>,
651 {
652 broadcast_to(self, shape)
653 }
654
655 /// Broadcasts an array to a specified shape.
656 ///
657 /// See also [`to_broadcast`].
658 pub fn to_broadcast_f<D2>(&self, shape: D2) -> Result<TensorView<'_, T, B, D2>>
659 where
660 D2: DimAPI,
661 D: DimMaxAPI<D2, Max = D2>,
662 {
663 to_broadcast_f(self, shape)
664 }
665
666 /// Broadcasts an array to a specified shape.
667 ///
668 /// See also [`to_broadcast`].
669 pub fn into_broadcast<D2>(self, shape: D2) -> TensorAny<R, T, B, D2>
670 where
671 D2: DimAPI,
672 D: DimMaxAPI<D2, Max = D2>,
673 {
674 into_broadcast(self, shape)
675 }
676
677 /// Broadcasts an array to a specified shape.
678 ///
679 /// See also [`to_broadcast`].
680 pub fn into_broadcast_f<D2>(self, shape: D2) -> Result<TensorAny<R, T, B, D2>>
681 where
682 D2: DimAPI,
683 D: DimMaxAPI<D2, Max = D2>,
684 {
685 into_broadcast_f(self, shape)
686 }
687}
688
689/* #endregion */