rstsr_core/tensor/manuplication/
reshape.rs

1use crate::prelude_dev::*;
2
3/* #region reshape */
4
5/// Reshapes the given tensor to the specified shape.
6///
7/// # See also
8///
9/// Refer to [`reshape`], [`into_shape`] and [`change_shape`] for more details and
10/// examples.
11pub fn change_shape_f<'a, I, R, T, B, D>(tensor: TensorAny<R, T, B, D>, shape: I) -> Result<TensorCow<'a, T, B, IxD>>
12where
13    I: TryInto<AxesIndex<isize>, Error = Error>,
14    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw> + DataIntoCowAPI<'a>,
15    D: DimAPI,
16    B: DeviceAPI<T> + DeviceRawAPI<MaybeUninit<T>> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, IxD, D>,
17{
18    // own shape, this is cheap operation
19    let shape_new = reshape_substitute_negatives(shape.try_into()?.as_ref(), tensor.size())?;
20    let default_order = tensor.device().default_order();
21    if let Some(layout_new) = layout_reshapeable(&tensor.layout().to_dim()?, &shape_new, default_order)? {
22        // shape does not need to be changed
23        let (storage, _) = tensor.into_raw_parts();
24        let layout = layout_new.into_dim::<IxD>()?;
25        return unsafe { Ok(TensorBase::new_unchecked(storage, layout).into_cow()) };
26    } else {
27        // clone underlying data by assign_arbitary
28        let (storage, layout) = tensor.into_raw_parts();
29        let device = storage.device();
30        let layout_new = match default_order {
31            RowMajor => shape_new.new_c_contig(None),
32            ColMajor => shape_new.new_f_contig(None),
33        };
34        let mut storage_new = device.uninit_impl(layout_new.size())?;
35        device.assign_arbitary_uninit(storage_new.raw_mut(), &layout_new, storage.raw(), &layout)?;
36        let storage_new = unsafe { B::assume_init_impl(storage_new)? };
37        return unsafe { Ok(TensorBase::new_unchecked(storage_new, layout_new).into_cow()) };
38    }
39}
40
41/// Reshapes the given tensor to the specified shape.
42///
43/// This function is not intended to be used by usual users. Please consider using
44/// [`reshape`] (take reference of tensor) or [`into_shape`] (take ownership of tensor)
45/// instead.
46///
47/// <div class="warning">
48///
49/// **Row/Column Major Notice**
50///
51/// This function behaves differently on default orders ([`RowMajor`] and [`ColMajor`]) of device.
52///
53/// </div>
54///
55/// # Parameters
56///
57/// - `tensor`: [`TensorAny<R, T, B, D>`]
58///
59///   - The input tensor to be reshaped.
60///   - Ownership of input tensor is taken.
61///
/// - `shape`: TryInto [`AxesIndex<isize>`]
///
///   - The desired shape of the output tensor.
///   - Can be a single integer, or a list/tuple of integers.
///   - A negative value (`-1`) means "infer this dimension" from the size of the tensor and the
///     other dimensions.
67///
68/// # Returns
69///
70/// - [`TensorCow<'a, T, B, IxD>`](TensorCow)
71///
72///   - The reshaped tensor.
73///   - This function will try to avoid data cloning if possible.
74///
75///     - If layout-compatible, depending on whether the input tensor is owned or other cases,
76///       either a view or owned tensor will be returned.
77///     - If layout-not-compatible, an owned tensor will be returned, cloning the data.
78///     - Cow (Clone-on-Write) semantics is used for representing either view or owned tensor.
79///
80/// This function is different to [`reshape`], in that it takes ownership of the input
81/// tensor.
82///
/// This function is also different to [`into_shape`], in that it may return a view, if the input
/// tensor is itself a view (it does not own the underlying data) and the layout is compatible.
85///
86/// # See also
87///
88/// Refer to [`reshape`] for more details and examples.
89pub fn change_shape<'a, I, R, T, B, D>(tensor: TensorAny<R, T, B, D>, shape: I) -> TensorCow<'a, T, B, IxD>
90where
91    I: TryInto<AxesIndex<isize>, Error = Error>,
92    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw> + DataIntoCowAPI<'a>,
93    D: DimAPI,
94    B: DeviceAPI<T> + DeviceRawAPI<MaybeUninit<T>> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, IxD, D>,
95{
96    change_shape_f(tensor, shape).rstsr_unwrap()
97}
98
99/// Reshapes the given tensor to the specified shape.
100///
101/// # See also
102///
103/// Refer to [`reshape`], [`into_shape`] and [`change_shape`] for more details and
104/// examples.
105pub fn into_shape_f<'a, I, R, T, B, D>(tensor: TensorAny<R, T, B, D>, shape: I) -> Result<Tensor<T, B, IxD>>
106where
107    I: TryInto<AxesIndex<isize>, Error = Error>,
108    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw> + DataIntoCowAPI<'a>,
109    D: DimAPI,
110    T: Clone,
111    B: DeviceAPI<T>
112        + DeviceRawAPI<MaybeUninit<T>>
113        + DeviceCreationAnyAPI<T>
114        + OpAssignArbitaryAPI<T, IxD, D>
115        + OpAssignAPI<T, IxD>,
116    <B as DeviceRawAPI<T>>::Raw: Clone + 'a,
117{
118    change_shape_f(tensor, shape).map(|v| v.into_owned())
119}
120
121/// Reshapes the given tensor to the specified shape.
122///
123/// <div class="warning">
124///
125/// **Row/Column Major Notice**
126///
127/// This function behaves differently on default orders ([`RowMajor`] and [`ColMajor`]) of device.
128///
129/// </div>
130///
131/// # Parameters
132///
133/// - `tensor`: [`TensorAny<R, T, B, D>`]
134///
135///   - The input tensor to be reshaped.
136///   - Ownership of input tensor is taken.
137///
/// - `shape`: TryInto [`AxesIndex<isize>`]
///
///   - The desired shape of the output tensor.
///   - Can be a single integer, or a list/tuple of integers.
///   - A negative value (`-1`) means "infer this dimension" from the size of the tensor and the
///     other dimensions.
143///
144/// # Returns
145///
146/// - [`Tensor<T, B, IxD>`]
147///
148///   - The reshaped tensor.
149///   - This function will try to avoid data cloning if possible, but with strict conditions:
150///
151///     - Layout-compatible after reshaping;
152///     - Input tensor owns the underlying data (i.e., not a view);
153///     - The input tensor is compact in memory (i.e., the underlying data does not have redundant
154///       elements; size of tensor exactly matches the length of underlying data).
155///
156/// This function is different to [`change_shape`](change_shape()) and [`reshape`], in
157/// that it takes ownership of the input tensor, and always returns an owned tensor.
158///
159/// # Examples
160///
161/// ```rust
162/// use rstsr::prelude::*;
163/// let a = rt::arange(6).into_shape([2, 3]);
164/// ```
165///
166/// # Elaborated examples
167///
/// Here are some showcases demonstrating when data cloning happens and when it does not. All
/// examples are row-major.
170///
171/// A first case is a tensor that is not fully contiguous (containing negative strides), but the
172/// tensor is compact (size of tensor is the same to the length of underlying data). In this case,
173/// if the new shape is compatible, no data cloning happens:
174///
175/// ```rust
176/// # use rstsr::prelude::*;
177/// # let mut device = DeviceCpu::default();
178/// # device.set_default_order(RowMajor);
179/// // shape: (4, 6, 9), stride: (-54, 9, 1), not c-contiguous
180/// // contiguous situation: (4, [6, 9]); the first dimension is reversed
181/// let a = rt::arange((216, &device)).into_shape([4, 6, 9]).into_flip(0);
182/// let a_ptr = a.raw().as_ptr();
183/// let b = a.into_shape([4, 54]);
184/// let b_ptr = b.raw().as_ptr();
185/// assert_eq!(a_ptr, b_ptr); // contiguous dims merged, no data clone happened
186/// ```
187///
188/// However, if the new shape is not compatible, data cloning will happen:
189///
190/// ```rust
191/// # use rstsr::prelude::*;
192/// # let mut device = DeviceCpu::default();
193/// # device.set_default_order(RowMajor);
194/// // shape: (4, 6, 9), stride: (-54, 9, 1), not c-contiguous
195/// // contiguous situation: (4, [6, 9]); the first dimension is reversed
196/// let a = rt::arange((216, &device)).into_shape([4, 6, 9]).into_flip(0);
197/// let a_ptr = a.raw().as_ptr();
198/// let b = a.into_shape([24, 9]);
199/// let b_ptr = b.raw().as_ptr();
200/// assert_ne!(a_ptr, b_ptr); // layout not compatible, data clone happened
201/// ```
202///
203/// Another case is a tensor that is not compact (size of tensor is less than the length of
204/// underlying data). In this case, even if the new shape is compatible, data cloning will happen:
205///
206/// ```rust
207/// # use rstsr::prelude::*;
208/// # let mut device = DeviceCpu::default();
209/// # device.set_default_order(RowMajor);
210/// // shape: (4, 6, 9), stride: (72, 9, 1), not c-contiguous
211/// // contiguous situation: (4, [6, 9]), or say the last two dimensions are contiguous
212/// let a = rt::arange((288, &device)).into_shape([4, 8, 9]).into_slice((.., 0..6, ..));
213/// let a_ptr = a.raw().as_ptr();
214/// let b = a.into_shape([4, 54]);
215/// let b_ptr = b.raw().as_ptr();
216/// assert_ne!(a_ptr, b_ptr); // layout-compatible, but input tensor is not compact (216 < 288)
217/// ```
218///
219/// # See also
220///
221/// Refer to [`reshape`] for more details and examples.
222pub fn into_shape<'a, I, R, T, B, D>(tensor: TensorAny<R, T, B, D>, shape: I) -> Tensor<T, B, IxD>
223where
224    I: TryInto<AxesIndex<isize>, Error = Error>,
225    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw> + DataIntoCowAPI<'a>,
226    D: DimAPI,
227    T: Clone,
228    B: DeviceAPI<T>
229        + DeviceRawAPI<MaybeUninit<T>>
230        + DeviceCreationAnyAPI<T>
231        + OpAssignArbitaryAPI<T, IxD, D>
232        + OpAssignAPI<T, IxD>,
233    <B as DeviceRawAPI<T>>::Raw: Clone + 'a,
234{
235    into_shape_f(tensor, shape).rstsr_unwrap()
236}
237
238/// Reshapes the given tensor to the specified shape.
239///
240/// # See also
241///
242/// Refer to [`reshape`], [`into_shape`] and [`change_shape`] for more details and
243/// examples.
244pub fn to_shape_f<'a, I, R, T, B, D>(tensor: &'a TensorAny<R, T, B, D>, shape: I) -> Result<TensorCow<'a, T, B, IxD>>
245where
246    I: TryInto<AxesIndex<isize>, Error = Error>,
247    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw> + DataIntoCowAPI<'a>,
248    D: DimAPI,
249    B: DeviceAPI<T> + DeviceRawAPI<MaybeUninit<T>> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, IxD, D>,
250{
251    change_shape_f(tensor.view(), shape)
252}
253
254/// Reshapes the given tensor to the specified shape.
255///
256/// # See also
257///
258/// Refer to [`reshape`], [`into_shape`] and [`change_shape`] for more details and
259/// examples.
260pub fn to_shape<'a, I, R, T, B, D>(tensor: &'a TensorAny<R, T, B, D>, shape: I) -> TensorCow<'a, T, B, IxD>
261where
262    I: TryInto<AxesIndex<isize>, Error = Error>,
263    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw> + DataIntoCowAPI<'a>,
264    D: DimAPI,
265    B: DeviceAPI<T> + DeviceRawAPI<MaybeUninit<T>> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, IxD, D>,
266{
267    to_shape_f(tensor, shape).rstsr_unwrap()
268}
269
270/// Reshapes the given tensor to the specified shape.
271///
272/// # See also
273///
274/// Refer to [`reshape`], [`into_shape`] and [`change_shape`] for more details and
275/// examples.
276pub fn reshape_f<'a, I, R, T, B, D>(tensor: &'a TensorAny<R, T, B, D>, shape: I) -> Result<TensorCow<'a, T, B, IxD>>
277where
278    I: TryInto<AxesIndex<isize>, Error = Error>,
279    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw> + DataIntoCowAPI<'a>,
280    D: DimAPI,
281    B: DeviceAPI<T> + DeviceRawAPI<MaybeUninit<T>> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, IxD, D>,
282{
283    to_shape_f(tensor, shape)
284}
285
286/// Reshapes the given tensor to the specified shape.
287///
288/// <div class="warning">
289///
290/// **Row/Column Major Notice**
291///
292/// This function behaves differently on default orders ([`RowMajor`] and [`ColMajor`]) of device.
293///
294/// </div>
295///
296/// # Parameters
297///
298/// - `tensor`: [`&TensorAny<R, T, B, D>`](TensorAny)
299///
300///   - The input tensor to be reshaped.
301///
/// - `shape`: TryInto [`AxesIndex<isize>`]
///
///   - The desired shape of the output tensor.
///   - Can be a single integer, or a list/tuple of integers.
///   - A negative value (`-1`) means "infer this dimension" from the size of the tensor and the
///     other dimensions.
307///
308/// # Returns
309///
310/// - [`TensorCow<'a, T, B, IxD>`](TensorCow)
311///
312///   - The reshaped tensor.
313///   - This function will try to avoid data cloning if possible.
314///
315///     - If layout-compatible, a view will be returned.
///     - If layout-not-compatible, an owned tensor will be returned, cloning the data.
317///     - Cow (Clone-on-Write) semantics is used for representing either view or owned tensor.
318///
319/// # Examples
320///
321/// In row-major order, to reshape a vector of (6, ) to a matrix of (2, 3):
322/// ```rust
323/// use rstsr::prelude::*;
324/// let mut device = DeviceCpu::default();
325/// device.set_default_order(RowMajor);
326///
327/// let a = rt::arange((6, &device));
328/// let a_reshaped = a.reshape([2, 3]);
329/// let a_expected = rt::tensor_from_nested!(
330///     [[0, 1, 2], [3, 4, 5]],
331///     &device);
332/// assert!(rt::allclose(&a_reshaped, &a_expected, None));
333/// ```
334///
335/// You can also use negative dimension, where -1 means "infer this dimension":
336///
337/// ```rust
/// # use rstsr::prelude::*;
/// # let mut device = DeviceCpu::default();
/// # device.set_default_order(RowMajor);
/// # let a = rt::arange((6, &device));
/// #
342/// // in this case, unspecified axes length is inferred as 6 / 3 = 2
343/// let a_reshaped = a.reshape([3, -1]);
344/// let a_expected = rt::tensor_from_nested!(
345///     [[0, 1], [2, 3], [4, 5]],
346///     &device);
347/// assert!(rt::allclose(&a_reshaped, &a_expected, None));
348/// ```
349///
350/// # Ownership Semantics between [`reshape`], [`into_shape`] and [`change_shape`]
351///
352/// [`into_shape`] and [`change_shape`] take ownership of the input tensor. They are important
353/// variants to this function [`reshape`].
354///
355/// | Function | Input Ownership | Output Ownership | Cloning Condition |
356/// |--|--|--|--|
357/// | [`reshape`] | Borrowed <br> [`&TensorAny`](TensorAny) | View <br> [`TensorCow`] with [`DataCow::Ref`] | not cloned (layout-compatible) |
358/// | | | Owned <br> [`TensorCow`] with [`DataCow::Owned`] | cloned (layout-not-compatible) |
359/// | [`into_shape`] | Owned <br> [`Tensor`] | Owned <br> [`Tensor`] | not cloned (layout-compatible, input tensor owns data, input tensor is compact) |
360/// | | | Owned <br> [`Tensor`] | cloned (otherwise) |
361/// | | Otherwise <br> [`TensorAny`] | Owned <br> [`Tensor`] | cloned (always) |
362/// | [`change_shape`] | Owned <br> [`Tensor`] | Owned <br> [`TensorCow`] with [`DataCow::Owned`] | not cloned (layout-compatible, input tensor owns data, input tensor is compact) |
363/// | | | Owned <br> [`TensorCow`] with [`DataCow::Owned`] | cloned (otherwise) |
364/// | | Otherwise <br> [`TensorAny`] | View <br> [`TensorCow`] with [`DataCow::Ref`] | not cloned (layout-compatible) |
365/// | | | Owned <br> [`TensorCow`] with [`DataCow::Owned`] | cloned (layout-not-compatible) |
366///
367/// # Tips on common compilation errors
368///
369/// You may encounter ownership problem when you try to assign a reshaped tensor like this:
370///
/// ```rust,compile_fail
/// # use rstsr::prelude::*;
/// # let mut device = DeviceCpu::default();
/// # device.set_default_order(RowMajor);
/// let a = rt::arange((6, &device)).reshape([2, 3]);
/// println!("a: {:?}", a);
/// ```
377///
378/// The compiler may give an error like:
379///
380/// ```text
381/// 704 |    let a = rt::arange((6, &device)).reshape([2, 3]);
382///     |            ^^^^^^^^^^^^^^^^^^^^^^^^                - temporary value is freed at the end of this statement
383///     |            |
384///     |            creates a temporary value which is freed while still in use
385/// 705 |    println!("a: {:?}", a);
386///     |                        - borrow later used here
387///     |
388/// help: consider using a `let` binding to create a longer lived value
389///     |
390/// 704 ~    let binding = rt::arange((6, &device));
391/// 705 ~    let a = binding.reshape([2, 3]);
392///     |
393/// ```
394///
395/// The suggestion by compiler is correct. However, you have another simpler way to solve this
396/// problem by using [`into_shape`] variant that takes ownership:
397///
398/// ```rust
399/// # use rstsr::prelude::*;
400/// # let mut device = DeviceCpu::default();
401/// # device.set_default_order(RowMajor);
402/// let a = rt::arange((6, &device)).into_shape([2, 3]);
403/// ```
404///
405/// # Notes of accordance
406///
407/// ## To Python Array API Standard
408///
409/// This function corresponds to Python Array API Standard:
410/// [`reshape`](https://data-apis.org/array-api/2024.12/API_specification/generated/array_api.reshape.html).
411///
412/// However, please note that this function does not implement the optional keyword `copy` as in
413/// the standard. `copy` keyword in the standard specifies whether to return a copy of the array
414/// data when the requested shape is not compatible with the original shape.
415///
416/// This function implements `copy = None` behavior in the standard, which means that it will return
417/// a view if possible, and return an owned tensor (cloning the data) if necessary.
418///
419/// To achieve similar functionality of optional keyword `copy`,
420///
421/// - For `copy = True` case, you are recommended to
422///
423///   - use [`into_shape`], which always returns an owned tensor, cloning the data if necessary. But
424///     note the necessity of cloning depends on the layout, and RSTSR may still not explicitly
425///     perform cloning.
426///   - use [`to_contig`], which always returns a contiguous owned tensor, cloning the data if
427///     necessary. But note that this function may still not explicitly perform cloning if the
428///     tensor is already contiguous.
429///   - use [`to_owned`](TensorAny::to_owned) as associated method to give an owned tensor, which
430///     always perform cloning.
431///
432/// - For `copy = False` case, you are recommended to
433///
434///   - use utility function [`layout_reshapeable`] to check whether the layout is compatible with
435///     the new shape.
436///
437/// ## To NumPy
438///
439/// This function corresponds to NumPy:
440/// [`reshape`](https://numpy.org/doc/stable/reference/generated/numpy.reshape.html).
441///
442/// However, please note that this function does not implement the optional keyword `order` as in
443/// the NumPy version. `order` keyword in NumPy specifies the iteration order to read elements from
444/// the tensor to-be-reshaped.
445///
446/// This function uses the device's default order to determine the layout of the reshaped tensor.
447/// You can check the device's current default order by
448/// [`device.default_order`](DeviceBaseAPI::default_order). Also see the elaborated examples below.
449///
450/// To change the device's default order, you can use
451///
452/// - [`device.set_default_order`](DeviceBaseAPI::set_default_order) to set the default order of a
453///   device instance, and then
454/// - [`change_device`](TensorDeviceChangeAPI::change_device) or
455///   [`into_device`](TensorDeviceChangeAPI::into_device) or
456///   [`to_device`](TensorDeviceChangeAPI::to_device) to change the tensor's device to the modified
457///   device. Choose the appropriate method depending on the desired ownership semantics.
458///
459/// # Elaborated examples
460///
461/// ## Difference between [RowMajor] and [ColMajor]
462///
463/// Tensor can be uniquely iterated (into a 1-dimension vector), for either row-major or
464/// column-major order.
465///
466/// **Reshape operation does not change the iterated sequence of a tensor**, by definition. In other
467/// words, the following code always holds true:
468///
469/// ```rust
470/// # use rstsr::prelude::*;
471/// # let mut device = DeviceCpu::default();
472/// # device.set_default_order(ColMajor);
473/// # let a = rt::tensor_from_nested!([[0, 1, 2], [3, 4, 5]], &device);
474/// # let b = a.reshape([3, 2]);
475/// // note iteration order of associated method `iter` depends on `device.default_order()`
476///
477/// // let b = a.reshape(... SOME SHAPE ...);
478/// let a_vec = a.iter().collect::<Vec<_>>();
479/// let b_vec = b.iter().collect::<Vec<_>>();
480/// assert_eq!(a_vec, b_vec); // iterated sequence is the same
481/// ```
482///
483/// For example, in row-major order, reshape a matrix of (2, 3) to (3, 2):
484///
485/// ```rust
486/// # use rstsr::prelude::*;
487/// # let mut device = DeviceCpu::default();
488/// // set to row-major order
489/// device.set_default_order(RowMajor);
490/// // a: [[0, 1, 2], [3, 4, 5]]
491/// // b: [[0, 1], [2, 3], [4, 5]]
492/// // iterated sequence: [0, 1, 2, 3, 4, 5]
493/// let a = rt::tensor_from_nested!([[0, 1, 2], [3, 4, 5]], &device);
494/// let b = a.reshape([3, 2]);
495/// let b_expected = rt::tensor_from_nested!([[0, 1], [2, 3], [4, 5]], &device);
496/// assert!(rt::allclose(&b, &b_expected, None));
497/// let a_vec = a.iter().cloned().collect::<Vec<_>>();
498/// let b_vec = b.iter().cloned().collect::<Vec<_>>();
499/// assert_eq!(a_vec, b_vec); // iterated sequence is the same
500/// assert_eq!(a_vec, vec![0, 1, 2, 3, 4, 5]);
501/// ```
502///
503/// In the column-major order, reshape the same matrix of (2, 3) to (3, 2) will yield a different
504/// result:
505///
506/// ```rust
507/// # use rstsr::prelude::*;
508/// # let mut device = DeviceCpu::default();
509/// // set to column-major order
510/// device.set_default_order(ColMajor);
511/// // a: [[0, 1, 2], [3, 4, 5]]
512/// // b: [[0, 4], [3, 2], [1, 5]]
513/// // iterated sequence: [0, 3, 1, 4, 2, 5]
514/// let a = rt::tensor_from_nested!([[0, 1, 2], [3, 4, 5]], &device);
515/// let b = a.reshape([3, 2]);
516/// let b_expected = rt::tensor_from_nested!([[0, 4], [3, 2], [1, 5]], &device);
517/// assert!(rt::allclose(&b, &b_expected, None));
518/// let a_vec = a.iter().cloned().collect::<Vec<_>>();
519/// let b_vec = b.iter().cloned().collect::<Vec<_>>();
520/// assert_eq!(a_vec, b_vec); // iterated sequence is the same
521/// assert_eq!(a_vec, vec![0, 3, 1, 4, 2, 5]);
522/// ```
523///
524/// ## Occasions of data cloning
525///
526/// The following discussion assumes the tensor is in row-major order. Similar discussion applies to
527/// column-major order.
528///
/// If the tensor to be reshaped is already C-contiguous on a row-major device, or F-contiguous on
/// a column-major device, then the reshape operation can be performed without any data cloning.
532///
533/// Otherwise, whether data cloning is necessary depends. For example, consider a tensor of shape
534/// (4, 6, 9) but with non-contiguous strides:
535///
536/// ```rust
537/// # use rstsr::prelude::*;
538/// # let mut device = DeviceCpu::default();
539/// # device.set_default_order(RowMajor);
540/// // shape: (4, 6, 9), stride: (72, 9, 1), not c-contiguous
541/// // contiguous situation: (4, [6, 9]), or say the last two dimensions are contiguous
542/// let a = rt::arange((288, &device)).into_shape([4, 8, 9]).into_slice((.., 0..6, ..));
543/// assert_eq!(a.shape(), &[4, 6, 9]);
544/// assert_eq!(a.stride(), &[72, 9, 1]);
545/// assert!(!a.c_contig());
546/// ```
547///
548/// Those cases will not require data cloning (returns a view, or [`DataCow::Ref`] internally):
549///
550/// ```rust
551/// # use rstsr::prelude::*;
552/// # let mut device = DeviceCpu::default();
553/// # device.set_default_order(RowMajor);
554/// # let a = rt::arange((288, &device)).into_shape([4, 8, 9]).into_slice((.., 0..6, ..));
555/// // split a single dimension into multiple dimensions
556/// assert!(!a.reshape([2, 2, 6, 9]).is_owned()); // (4, 6, 9) -> ([2, 2], 6, 9)
557/// assert!(!a.reshape([4, 3, 2, 9]).is_owned()); // (4, 6, 9) -> (4, [3, 2], 9)
558/// assert!(!a.reshape([4, 2, 3, 3, 3]).is_owned()); // (4, 6, 9) -> (4, [2, 3], [3, 3])
559///
560/// // merge contiguous dimensions into a single dimension
561/// assert!(!a.reshape([4, 54]).is_owned()); // (4, 6, 9) -> (4, 6 * 9)
562///
563/// // merge contiguous dimensions and then split
564/// assert!(!a.reshape([4, 3, 6, 3]).is_owned()); // (4, [6, 9]) -> (4, [3, 6, 3])
565/// ```
566///
567/// However, the following cases will require data cloning (returns an owned tensor, or
568/// [`DataCow::Owned`] internally):
569///
570/// ```rust
571/// # use rstsr::prelude::*;
572/// # let mut device = DeviceCpu::default();
573/// # device.set_default_order(RowMajor);
574/// # let a = rt::arange((288, &device)).into_shape([4, 8, 9]).into_slice((.., 0..6, ..));
575/// assert!(a.reshape([24, 9]).is_owned()); // (4, 6, 9) -> (4 * 6, 9)
576/// assert!(a.reshape(-1).is_owned()); // (4, 6, 9) -> (4 * 6 * 9)
577/// assert!(a.reshape([12, 2, 9]).is_owned()); // (4, 6, 9) -> (4 * [3, 2], 9)
578/// ```
579///
580/// # See also
581///
582/// ## Similar function from other crates/libraries
583///
584/// - Python Array API standard: [`reshape`](https://data-apis.org/array-api/2024.12/API_specification/generated/array_api.reshape.html)
585/// - NumPy: [`reshape`](https://numpy.org/doc/stable/reference/generated/numpy.reshape.html)
586/// - ndarray: [`to_shape`](https://docs.rs/ndarray/latest/ndarray/struct.ArrayBase.html#method.to_shape)
587///
588/// ## Related functions in RSTSR
589///
590/// - [`reshape_assume_contig`]: Reshape assuming the tensor is contiguous.
591/// - [`layout_reshapeable`]: Check whether the layout is compatible with the new shape.
592/// - [`to_layout`]: Return a tensor with the specified layout.
593/// - [`to_contig`]: Return an owned contiguous tensor.
594///
595/// ## Variants of this function
596///
597/// - [`reshape`] / [`reshape_f`]: Taking reference and returning Cow.
598/// - [`into_shape`] / [`into_shape_f`]: Taking ownership and returning owned tensor.
599/// - [`change_shape`] / [`change_shape_f`]: Taking ownership and returning Cow.
600/// - [`to_shape`] / [`to_shape_f`]: Alias to [`reshape`] / [`reshape_f`].
601/// - Associated methods on [`TensorAny`]:
602///
603///   - [`TensorAny::reshape`] / [`TensorAny::reshape_f`]
604///   - [`TensorAny::into_shape`] / [`TensorAny::into_shape_f`]
605///   - [`TensorAny::change_shape`] / [`TensorAny::change_shape_f`]
606///   - [`TensorAny::to_shape`] / [`TensorAny::to_shape_f`]
607pub fn reshape<'a, I, R, T, B, D>(tensor: &'a TensorAny<R, T, B, D>, shape: I) -> TensorCow<'a, T, B, IxD>
608where
609    I: TryInto<AxesIndex<isize>, Error = Error>,
610    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw> + DataIntoCowAPI<'a>,
611    D: DimAPI,
612    B: DeviceAPI<T> + DeviceRawAPI<MaybeUninit<T>> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, IxD, D>,
613{
614    to_shape(tensor, shape)
615}
616
617/// Reshapes the given tensor to the specified shape.
618///
619/// # See also
620///
621/// Refer to [`reshape`], [`into_shape`] and [`change_shape`] for more details and
622/// examples.
623impl<'a, R, T, B, D> TensorAny<R, T, B, D>
624where
625    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw> + DataIntoCowAPI<'a>,
626    D: DimAPI,
627    B: DeviceAPI<T> + DeviceRawAPI<MaybeUninit<T>> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, IxD, D>,
628    T: Clone,
629{
630    /// Reshapes the given tensor to the specified shape.
631    ///
632    /// # See also
633    ///
634    /// Refer to [`reshape`], [`into_shape`] and [`change_shape`] for more details and
635    /// examples.
636    pub fn change_shape_f<I>(self, shape: I) -> Result<TensorCow<'a, T, B, IxD>>
637    where
638        I: TryInto<AxesIndex<isize>, Error = Error>,
639    {
640        change_shape_f(self, shape)
641    }
642
643    /// Reshapes the given tensor to the specified shape.
644    ///
645    /// # See also
646    ///
647    /// Refer to [`reshape`], [`into_shape`] and [`change_shape`] for more details and
648    /// examples.
649    pub fn change_shape<I>(self, shape: I) -> TensorCow<'a, T, B, IxD>
650    where
651        I: TryInto<AxesIndex<isize>, Error = Error>,
652    {
653        change_shape(self, shape)
654    }
655
656    /// Reshapes the given tensor to the specified shape.
657    ///
658    /// # See also
659    ///
660    /// Refer to [`reshape`], [`into_shape`] and [`change_shape`] for more details and
661    /// examples.
662    pub fn into_shape_f<I>(self, shape: I) -> Result<Tensor<T, B, IxD>>
663    where
664        I: TryInto<AxesIndex<isize>, Error = Error>,
665        <B as DeviceRawAPI<T>>::Raw: Clone + 'a,
666        B: OpAssignAPI<T, IxD>,
667    {
668        into_shape_f(self, shape)
669    }
670
671    /// Reshapes the given tensor to the specified shape.
672    ///
673    /// # See also
674    ///
675    /// Refer to [`reshape`], [`into_shape`] and [`change_shape`] for more details and
676    /// examples.
677    pub fn into_shape<I>(self, shape: I) -> Tensor<T, B, IxD>
678    where
679        I: TryInto<AxesIndex<isize>, Error = Error>,
680        <B as DeviceRawAPI<T>>::Raw: Clone + 'a,
681        B: OpAssignAPI<T, IxD>,
682    {
683        into_shape(self, shape)
684    }
685
686    /// Reshapes the given tensor to the specified shape.
687    ///
688    /// # See also
689    ///
690    /// Refer to [`reshape`], [`into_shape`] and [`change_shape`] for more details and
691    /// examples.
692    pub fn to_shape_f<I>(&'a self, shape: I) -> Result<TensorCow<'a, T, B, IxD>>
693    where
694        I: TryInto<AxesIndex<isize>, Error = Error>,
695    {
696        self.view().change_shape_f(shape)
697    }
698
699    /// Reshapes the given tensor to the specified shape.
700    ///
701    /// # See also
702    ///
703    /// Refer to [`reshape`], [`into_shape`] and [`change_shape`] for more details and
704    /// examples.
705    pub fn to_shape<I>(&'a self, shape: I) -> TensorCow<'a, T, B, IxD>
706    where
707        I: TryInto<AxesIndex<isize>, Error = Error>,
708    {
709        self.view().change_shape(shape)
710    }
711
712    /// Reshapes the given tensor to the specified shape.
713    ///
714    /// # See also
715    ///
716    /// Refer to [`reshape`], [`into_shape`] and [`change_shape`] for more details and
717    /// examples.
718    pub fn reshape_f<I>(&'a self, shape: I) -> Result<TensorCow<'a, T, B, IxD>>
719    where
720        I: TryInto<AxesIndex<isize>, Error = Error>,
721    {
722        self.view().change_shape_f(shape)
723    }
724
725    /// Reshapes the given tensor to the specified shape.
726    ///
727    /// # See also
728    ///
729    /// Refer to [`reshape`], [`into_shape`] and [`change_shape`] for more details and
730    /// examples.
731    pub fn reshape<I>(&'a self, shape: I) -> TensorCow<'a, T, B, IxD>
732    where
733        I: TryInto<AxesIndex<isize>, Error = Error>,
734    {
735        self.view().change_shape(shape)
736    }
737}
738
739/* #endregion */
740
#[cfg(test)]
mod tests {
    /// Basic usage: an explicit target shape, then a target shape with one `-1` axis.
    #[test]
    #[rustfmt::skip]
    fn doc_reshape() {
        use rstsr::prelude::*;
        let mut device = DeviceCpu::default();
        device.set_default_order(RowMajor);

        let source = rt::arange((6, &device));

        // six elements arranged as 2 rows x 3 columns
        let reshaped_2x3 = source.reshape([2, 3]);
        let expected_2x3 = rt::tensor_from_nested!(
            [[0, 1, 2], [3, 4, 5]],
            &device);
        assert!(rt::allclose(&reshaped_2x3, &expected_2x3, None));

        // `-1` leaves one axis length unspecified; here it is inferred as 6 / 3 = 2
        let reshaped_3x2 = source.reshape([3, -1]);
        let expected_3x2 = rt::tensor_from_nested!(
            [[0, 1], [2, 3], [4, 5]],
            &device);
        assert!(rt::allclose(&reshaped_3x2, &expected_3x2, None));
    }

    /// Compares reshape under row-major and column-major default orders.
    #[test]
    fn doc_reshape_elaborated_diff_row_col() {
        use rstsr::prelude::*;

        // Row-major device: source and reshaped tensor iterate in the same sequence.
        let mut device_row = DeviceCpu::default();
        device_row.set_default_order(RowMajor);
        let row_src = rt::tensor_from_nested!([[0, 1, 2], [3, 4, 5]], &device_row);
        let row_dst = row_src.reshape([3, 2]);
        let row_src_seq = row_src.iter().collect::<Vec<_>>();
        let row_dst_seq = row_dst.iter().collect::<Vec<_>>();
        assert_eq!(row_src_seq, row_dst_seq);

        // Column-major device: still self-consistent, but the sequence itself differs
        // from the row-major one collected above.
        let mut device_col = DeviceCpu::default();
        device_col.set_default_order(ColMajor);
        let col_src = rt::tensor_from_nested!([[0, 1, 2], [3, 4, 5]], &device_col);
        let col_dst = col_src.reshape([3, 2]);
        let col_src_seq = col_src.iter().collect::<Vec<_>>();
        let col_dst_seq = col_dst.iter().collect::<Vec<_>>();
        assert_eq!(col_src_seq, col_dst_seq);
        assert_ne!(col_src_seq, row_src_seq);

        // Row-major reshape, with concrete values:
        //   source:   [[0, 1, 2], [3, 4, 5]]
        //   reshaped: [[0, 1], [2, 3], [4, 5]]
        //   sequence: [0, 1, 2, 3, 4, 5]
        {
            let mut device = DeviceCpu::default();
            device.set_default_order(RowMajor);
            let src = rt::tensor_from_nested!([[0, 1, 2], [3, 4, 5]], &device);
            let dst = src.reshape([3, 2]);
            let dst_expected = rt::tensor_from_nested!([[0, 1], [2, 3], [4, 5]], &device);
            assert!(rt::allclose(&dst, &dst_expected, None));
            let src_seq = src.iter().cloned().collect::<Vec<_>>();
            let dst_seq = dst.iter().cloned().collect::<Vec<_>>();
            assert_eq!(src_seq, dst_seq);
            assert_eq!(src_seq, vec![0, 1, 2, 3, 4, 5]);
        }

        // Column-major reshape, with concrete values:
        //   source:   [[0, 1, 2], [3, 4, 5]]
        //   reshaped: [[0, 4], [3, 2], [1, 5]]
        //   sequence: [0, 3, 1, 4, 2, 5]
        {
            let mut device = DeviceCpu::default();
            device.set_default_order(ColMajor);
            let src = rt::tensor_from_nested!([[0, 1, 2], [3, 4, 5]], &device);
            let dst = src.reshape([3, 2]);
            let dst_expected = rt::tensor_from_nested!([[0, 4], [3, 2], [1, 5]], &device);
            assert!(rt::allclose(&dst, &dst_expected, None));
            let src_seq = src.iter().cloned().collect::<Vec<_>>();
            let dst_seq = dst.iter().cloned().collect::<Vec<_>>();
            assert_eq!(src_seq, dst_seq);
            assert_eq!(src_seq, vec![0, 3, 1, 4, 2, 5]);
        }
    }

    /// Shows which target shapes yield a non-owned result and which yield an owned one.
    #[test]
    fn doc_reshape_elaborated_clone_occasion() {
        use rstsr::prelude::*;

        let mut device = DeviceCpu::default();
        device.set_default_order(RowMajor);

        // A strided tensor: shape (4, 6, 9) with stride (72, 9, 1), so it is not
        // c-contiguous. Only the last two dimensions are contiguous: (4, [6, 9]).
        let strided = rt::arange((288, &device)).into_shape([4, 8, 9]).into_slice((.., 0..6, ..));
        assert_eq!(strided.shape(), &[4, 6, 9]);
        assert_eq!(strided.stride(), &[72, 9, 1]);
        assert!(!strided.c_contig());

        // Case 1: no clone needed, result is not owned.

        // one dimension split into several
        assert!(!strided.reshape([2, 2, 6, 9]).is_owned()); // (4, 6, 9) -> ([2, 2], 6, 9)
        assert!(!strided.reshape([4, 3, 2, 9]).is_owned()); // (4, 6, 9) -> (4, [3, 2], 9)
        assert!(!strided.reshape([4, 2, 3, 3, 3]).is_owned()); // (4, 6, 9) -> (4, [2, 3], [3, 3])

        // contiguous dimensions merged into one
        assert!(!strided.reshape([4, 54]).is_owned()); // (4, 6, 9) -> (4, 6 * 9)

        // contiguous dimensions merged, then split again
        assert!(!strided.reshape([4, 3, 6, 3]).is_owned()); // (4, [6, 9]) -> (4, [3, 6, 3])

        // Case 2: clone needed, result is owned.

        // non-contiguous dimensions merged
        assert!(strided.reshape([24, 9]).is_owned()); // (4, 6, 9) -> (4 * 6, 9)
        assert!(strided.reshape(-1).is_owned()); // (4, 6, 9) -> (4 * 6 * 9)
        assert!(strided.reshape([12, 2, 9]).is_owned()); // (4, 6, 9) -> (4 * [3, 2], 9)
    }

    /// Checks, via raw data pointers, when `into_shape` reuses the buffer and when it does not.
    #[test]
    fn doc_into_shape() {
        use rstsr::prelude::*;
        let mut device = DeviceCpu::default();
        device.set_default_order(RowMajor);

        let small = rt::arange((6, &device)).into_shape([2, 3]);
        println!("a: {:?}", small);

        // Flipped first axis: shape (4, 6, 9), stride (-54, 9, 1), not c-contiguous,
        // contiguous as (4, [6, 9]). Merging the contiguous dims keeps the same buffer.
        {
            let flipped = rt::arange((216, &device)).into_shape([4, 6, 9]).into_flip(0);
            let ptr_before = flipped.raw().as_ptr();
            let merged = flipped.into_shape([4, 54]);
            let ptr_after = merged.raw().as_ptr();
            assert_eq!(ptr_before, ptr_after); // no data clone happened
        }

        // Same flipped tensor, but the target merges the first two dims: the layout is
        // not compatible, so the data ends up in a different buffer.
        {
            let flipped = rt::arange((216, &device)).into_shape([4, 6, 9]).into_flip(0);
            let ptr_before = flipped.raw().as_ptr();
            let merged = flipped.into_shape([24, 9]);
            let ptr_after = merged.raw().as_ptr();
            assert_ne!(ptr_before, ptr_after); // data clone happened
        }

        // Sliced tensor: shape (4, 6, 9), stride (72, 9, 1), contiguous as (4, [6, 9]).
        // The target layout is compatible, but the input is not compact (216 < 288),
        // so the result still lives in a different buffer.
        {
            let sliced = rt::arange((288, &device)).into_shape([4, 8, 9]).into_slice((.., 0..6, ..));
            let ptr_before = sliced.raw().as_ptr();
            let merged = sliced.into_shape([4, 54]);
            let ptr_after = merged.raw().as_ptr();
            assert_ne!(ptr_before, ptr_after);
        }
    }
}