rstsr_core/tensor/manuplication/reshape.rs
use crate::prelude_dev::*;

/* #region reshape */

/// Reshapes the given tensor to the specified shape.
///
/// # See also
///
/// Refer to [`reshape`], [`into_shape`] and [`change_shape`] for more details and
/// examples.
pub fn change_shape_f<'a, I, R, T, B, D>(tensor: TensorAny<R, T, B, D>, shape: I) -> Result<TensorCow<'a, T, B, IxD>>
where
    I: TryInto<AxesIndex<isize>, Error = Error>,
    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw> + DataIntoCowAPI<'a>,
    D: DimAPI,
    B: DeviceAPI<T> + DeviceRawAPI<MaybeUninit<T>> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, IxD, D>,
{
    // own the shape and resolve negative entries; this is a cheap operation
    let shape_new = reshape_substitute_negatives(shape.try_into()?.as_ref(), tensor.size())?;
    let default_order = tensor.device().default_order();
    if let Some(layout_new) = layout_reshapeable(&tensor.layout().to_dim()?, &shape_new, default_order)? {
        // layout is compatible; reuse the underlying data without cloning
        let (storage, _) = tensor.into_raw_parts();
        let layout = layout_new.into_dim::<IxD>()?;
        return unsafe { Ok(TensorBase::new_unchecked(storage, layout).into_cow()) };
    } else {
        // layout is not compatible; clone the underlying data by assign_arbitary
        let (storage, layout) = tensor.into_raw_parts();
        let device = storage.device();
        let layout_new = match default_order {
            RowMajor => shape_new.new_c_contig(None),
            ColMajor => shape_new.new_f_contig(None),
        };
        let mut storage_new = device.uninit_impl(layout_new.size())?;
        device.assign_arbitary_uninit(storage_new.raw_mut(), &layout_new, storage.raw(), &layout)?;
        let storage_new = unsafe { B::assume_init_impl(storage_new)? };
        return unsafe { Ok(TensorBase::new_unchecked(storage_new, layout_new).into_cow()) };
    }
}

/// Reshapes the given tensor to the specified shape.
///
/// This function is not intended for typical use. Please consider using [`reshape`] (takes a
/// reference to the tensor) or [`into_shape`] (takes ownership of the tensor) instead.
///
/// <div class="warning">
///
/// **Row/Column Major Notice**
///
/// This function behaves differently on default orders ([`RowMajor`] and [`ColMajor`]) of device.
///
/// </div>
///
/// # Parameters
///
/// - `tensor`: [`TensorAny<R, T, B, D>`]
///
///     - The input tensor to be reshaped.
///     - Ownership of the input tensor is taken.
///
/// - `shape`: TryInto [`AxesIndex<isize>`]
///
///     - The desired shape of the output tensor.
///     - Can be a single integer, or a list/tuple of integers.
///     - At most one dimension may be -1, meaning that its length is inferred from the size of
///       the tensor and the remaining dimensions.
///
/// # Returns
///
/// - [`TensorCow<'a, T, B, IxD>`](TensorCow)
///
///     - The reshaped tensor.
///     - This function will try to avoid data cloning if possible.
///
///         - If layout-compatible, either a view or an owned tensor will be returned, depending
///           on whether the input tensor owns its data.
///         - If layout-not-compatible, an owned tensor will be returned, cloning the data.
///         - Cow (Clone-on-Write) semantics is used to represent either a view or an owned
///           tensor.
///
/// This function differs from [`reshape`] in that it takes ownership of the input tensor.
///
/// This function also differs from [`into_shape`] in that it may return a view, if the input
/// tensor is itself a view and the layout is compatible.
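///
/// # Examples
///
/// A minimal sketch of the Cow semantics (assuming a CPU device with row-major default order):
///
/// ```rust
/// use rstsr::prelude::*;
/// let mut device = DeviceCpu::default();
/// device.set_default_order(RowMajor);
/// let a = rt::arange((6, &device));
/// // a borrowed view stays a view when the layout is compatible
/// assert!(!a.view().change_shape([2, 3]).is_owned());
/// // taking ownership of an owned tensor yields an owned Cow, even without cloning
/// assert!(a.change_shape([2, 3]).is_owned());
/// ```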
///
/// # See also
///
/// Refer to [`reshape`] for more details and examples.
pub fn change_shape<'a, I, R, T, B, D>(tensor: TensorAny<R, T, B, D>, shape: I) -> TensorCow<'a, T, B, IxD>
where
    I: TryInto<AxesIndex<isize>, Error = Error>,
    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw> + DataIntoCowAPI<'a>,
    D: DimAPI,
    B: DeviceAPI<T> + DeviceRawAPI<MaybeUninit<T>> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, IxD, D>,
{
    change_shape_f(tensor, shape).rstsr_unwrap()
}

/// Reshapes the given tensor to the specified shape.
///
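/// The `_f` suffix marks the fallible variant, which returns `Result` instead of panicking.
/// A minimal sketch (assuming that a size-mismatched shape is reported as `Err`):
///
/// ```rust
/// use rstsr::prelude::*;
/// let a = rt::arange(6);
/// // 4 * 5 != 6, so the fallible variant returns an error
/// assert!(a.into_shape_f([4, 5]).is_err());
/// ```
///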
/// # See also
///
/// Refer to [`reshape`], [`into_shape`] and [`change_shape`] for more details and
/// examples.
pub fn into_shape_f<'a, I, R, T, B, D>(tensor: TensorAny<R, T, B, D>, shape: I) -> Result<Tensor<T, B, IxD>>
where
    I: TryInto<AxesIndex<isize>, Error = Error>,
    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw> + DataIntoCowAPI<'a>,
    D: DimAPI,
    T: Clone,
    B: DeviceAPI<T>
        + DeviceRawAPI<MaybeUninit<T>>
        + DeviceCreationAnyAPI<T>
        + OpAssignArbitaryAPI<T, IxD, D>
        + OpAssignAPI<T, IxD>,
    <B as DeviceRawAPI<T>>::Raw: Clone + 'a,
{
    change_shape_f(tensor, shape).map(|v| v.into_owned())
}

/// Reshapes the given tensor to the specified shape.
///
/// <div class="warning">
///
/// **Row/Column Major Notice**
///
/// This function behaves differently on default orders ([`RowMajor`] and [`ColMajor`]) of device.
///
/// </div>
///
/// # Parameters
///
/// - `tensor`: [`TensorAny<R, T, B, D>`]
///
///     - The input tensor to be reshaped.
///     - Ownership of the input tensor is taken.
///
/// - `shape`: TryInto [`AxesIndex<isize>`]
///
///     - The desired shape of the output tensor.
///     - Can be a single integer, or a list/tuple of integers.
///     - At most one dimension may be -1, meaning that its length is inferred from the size of
///       the tensor and the remaining dimensions.
///
/// # Returns
///
/// - [`Tensor<T, B, IxD>`]
///
///     - The reshaped tensor.
///     - This function will try to avoid data cloning if possible, but only under strict
///       conditions:
///
///         - Layout-compatible after reshaping;
///         - Input tensor owns the underlying data (i.e., not a view);
///         - The input tensor is compact in memory (i.e., the underlying data does not have
///           redundant elements; the size of the tensor exactly matches the length of the
///           underlying data).
///
/// This function differs from [`change_shape`](change_shape()) and [`reshape`] in that it takes
/// ownership of the input tensor and always returns an owned tensor.
///
/// # Examples
///
/// ```rust
/// use rstsr::prelude::*;
/// let a = rt::arange(6).into_shape([2, 3]);
/// ```
///
/// # Elaborated examples
///
/// Here are some showcases that demonstrate when data cloning happens and when it does not. All
/// examples are row-major.
///
/// The first case is a tensor that is not fully contiguous (containing negative strides), but
/// still compact (the size of the tensor equals the length of the underlying data). In this
/// case, if the new shape is compatible, no data cloning happens:
///
/// ```rust
/// # use rstsr::prelude::*;
/// # let mut device = DeviceCpu::default();
/// # device.set_default_order(RowMajor);
/// // shape: (4, 6, 9), stride: (-54, 9, 1), not c-contiguous
/// // contiguous situation: (4, [6, 9]); the first dimension is reversed
/// let a = rt::arange((216, &device)).into_shape([4, 6, 9]).into_flip(0);
/// let a_ptr = a.raw().as_ptr();
/// let b = a.into_shape([4, 54]);
/// let b_ptr = b.raw().as_ptr();
/// assert_eq!(a_ptr, b_ptr); // contiguous dims merged, no data clone happened
/// ```
///
/// However, if the new shape is not compatible, data cloning will happen:
///
/// ```rust
/// # use rstsr::prelude::*;
/// # let mut device = DeviceCpu::default();
/// # device.set_default_order(RowMajor);
/// // shape: (4, 6, 9), stride: (-54, 9, 1), not c-contiguous
/// // contiguous situation: (4, [6, 9]); the first dimension is reversed
/// let a = rt::arange((216, &device)).into_shape([4, 6, 9]).into_flip(0);
/// let a_ptr = a.raw().as_ptr();
/// let b = a.into_shape([24, 9]);
/// let b_ptr = b.raw().as_ptr();
/// assert_ne!(a_ptr, b_ptr); // layout not compatible, data clone happened
/// ```
///
/// Another case is a tensor that is not compact (the size of the tensor is less than the length
/// of the underlying data). In this case, even if the new shape is compatible, data cloning will
/// happen:
///
/// ```rust
/// # use rstsr::prelude::*;
/// # let mut device = DeviceCpu::default();
/// # device.set_default_order(RowMajor);
/// // shape: (4, 6, 9), stride: (72, 9, 1), not c-contiguous
/// // contiguous situation: (4, [6, 9]), or say the last two dimensions are contiguous
/// let a = rt::arange((288, &device)).into_shape([4, 8, 9]).into_slice((.., 0..6, ..));
/// let a_ptr = a.raw().as_ptr();
/// let b = a.into_shape([4, 54]);
/// let b_ptr = b.raw().as_ptr();
/// assert_ne!(a_ptr, b_ptr); // layout-compatible, but input tensor is not compact (216 < 288)
/// ```
///
/// # See also
///
/// Refer to [`reshape`] for more details and examples.
pub fn into_shape<'a, I, R, T, B, D>(tensor: TensorAny<R, T, B, D>, shape: I) -> Tensor<T, B, IxD>
where
    I: TryInto<AxesIndex<isize>, Error = Error>,
    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw> + DataIntoCowAPI<'a>,
    D: DimAPI,
    T: Clone,
    B: DeviceAPI<T>
        + DeviceRawAPI<MaybeUninit<T>>
        + DeviceCreationAnyAPI<T>
        + OpAssignArbitaryAPI<T, IxD, D>
        + OpAssignAPI<T, IxD>,
    <B as DeviceRawAPI<T>>::Raw: Clone + 'a,
{
    into_shape_f(tensor, shape).rstsr_unwrap()
}

/// Reshapes the given tensor to the specified shape.
///
/// # See also
///
/// Refer to [`reshape`], [`into_shape`] and [`change_shape`] for more details and
/// examples.
pub fn to_shape_f<'a, I, R, T, B, D>(tensor: &'a TensorAny<R, T, B, D>, shape: I) -> Result<TensorCow<'a, T, B, IxD>>
where
    I: TryInto<AxesIndex<isize>, Error = Error>,
    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw> + DataIntoCowAPI<'a>,
    D: DimAPI,
    B: DeviceAPI<T> + DeviceRawAPI<MaybeUninit<T>> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, IxD, D>,
{
    change_shape_f(tensor.view(), shape)
}

/// Reshapes the given tensor to the specified shape.
///
/// # See also
///
/// Refer to [`reshape`], [`into_shape`] and [`change_shape`] for more details and
/// examples.
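///
/// A minimal sketch (method form; the behavior is identical to [`reshape`]):
///
/// ```rust
/// use rstsr::prelude::*;
/// let a = rt::arange(6);
/// // contiguous input with a compatible layout gives a view, without data cloning
/// assert!(!a.to_shape([2, 3]).is_owned());
/// ```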
pub fn to_shape<'a, I, R, T, B, D>(tensor: &'a TensorAny<R, T, B, D>, shape: I) -> TensorCow<'a, T, B, IxD>
where
    I: TryInto<AxesIndex<isize>, Error = Error>,
    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw> + DataIntoCowAPI<'a>,
    D: DimAPI,
    B: DeviceAPI<T> + DeviceRawAPI<MaybeUninit<T>> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, IxD, D>,
{
    to_shape_f(tensor, shape).rstsr_unwrap()
}

/// Reshapes the given tensor to the specified shape.
///
/// # See also
///
/// Refer to [`reshape`], [`into_shape`] and [`change_shape`] for more details and
/// examples.
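///
/// The `_f` suffix marks the fallible variant, which returns `Result` instead of panicking.
/// A minimal sketch (assuming that a size-mismatched shape is reported as `Err`):
///
/// ```rust
/// use rstsr::prelude::*;
/// let a = rt::arange(6);
/// // 3 * 4 != 6, so the fallible variant returns an error
/// assert!(a.reshape_f([3, 4]).is_err());
/// ```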
pub fn reshape_f<'a, I, R, T, B, D>(tensor: &'a TensorAny<R, T, B, D>, shape: I) -> Result<TensorCow<'a, T, B, IxD>>
where
    I: TryInto<AxesIndex<isize>, Error = Error>,
    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw> + DataIntoCowAPI<'a>,
    D: DimAPI,
    B: DeviceAPI<T> + DeviceRawAPI<MaybeUninit<T>> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, IxD, D>,
{
    to_shape_f(tensor, shape)
}

/// Reshapes the given tensor to the specified shape.
///
/// <div class="warning">
///
/// **Row/Column Major Notice**
///
/// This function behaves differently on default orders ([`RowMajor`] and [`ColMajor`]) of device.
///
/// </div>
///
/// # Parameters
///
/// - `tensor`: [`&TensorAny<R, T, B, D>`](TensorAny)
///
///     - The input tensor to be reshaped.
///
/// - `shape`: TryInto [`AxesIndex<isize>`]
///
///     - The desired shape of the output tensor.
///     - Can be a single integer, or a list/tuple of integers.
///     - At most one dimension may be -1, meaning that its length is inferred from the size of
///       the tensor and the remaining dimensions.
///
/// # Returns
///
/// - [`TensorCow<'a, T, B, IxD>`](TensorCow)
///
///     - The reshaped tensor.
///     - This function will try to avoid data cloning if possible.
///
///         - If layout-compatible, a view will be returned.
///         - If layout-not-compatible, an owned tensor will be returned, cloning the data.
///         - Cow (Clone-on-Write) semantics is used to represent either a view or an owned
///           tensor.
///
/// # Examples
///
/// In row-major order, to reshape a (6,) vector to a (2, 3) matrix:
///
/// ```rust
/// use rstsr::prelude::*;
/// let mut device = DeviceCpu::default();
/// device.set_default_order(RowMajor);
///
/// let a = rt::arange((6, &device));
/// let a_reshaped = a.reshape([2, 3]);
/// let a_expected = rt::tensor_from_nested!(
///     [[0, 1, 2], [3, 4, 5]],
///     &device);
/// assert!(rt::allclose(&a_reshaped, &a_expected, None));
/// ```
///
/// You can also use a negative dimension, where -1 means "infer this dimension":
///
/// ```rust
/// # use rstsr::prelude::*;
/// # let mut device = DeviceCpu::default();
/// # device.set_default_order(RowMajor);
/// # let a = rt::arange((6, &device));
/// #
/// // in this case, the unspecified axis length is inferred as 6 / 3 = 2
/// let a_reshaped = a.reshape([3, -1]);
/// let a_expected = rt::tensor_from_nested!(
///     [[0, 1], [2, 3], [4, 5]],
///     &device);
/// assert!(rt::allclose(&a_reshaped, &a_expected, None));
/// ```
///
/// # Ownership Semantics between [`reshape`], [`into_shape`] and [`change_shape`]
///
/// [`into_shape`] and [`change_shape`] take ownership of the input tensor. They are important
/// variants of [`reshape`].
///
/// | Function | Input Ownership | Output Ownership | Cloning Condition |
/// |--|--|--|--|
/// | [`reshape`] | Borrowed <br> [`&TensorAny`](TensorAny) | View <br> [`TensorCow`] with [`DataCow::Ref`] | not cloned (layout-compatible) |
/// | | | Owned <br> [`TensorCow`] with [`DataCow::Owned`] | cloned (layout-not-compatible) |
/// | [`into_shape`] | Owned <br> [`Tensor`] | Owned <br> [`Tensor`] | not cloned (layout-compatible, input tensor owns data, input tensor is compact) |
/// | | | Owned <br> [`Tensor`] | cloned (otherwise) |
/// | | Otherwise <br> [`TensorAny`] | Owned <br> [`Tensor`] | cloned (always) |
/// | [`change_shape`] | Owned <br> [`Tensor`] | Owned <br> [`TensorCow`] with [`DataCow::Owned`] | not cloned (layout-compatible, input tensor owns data, input tensor is compact) |
/// | | | Owned <br> [`TensorCow`] with [`DataCow::Owned`] | cloned (otherwise) |
/// | | Otherwise <br> [`TensorAny`] | View <br> [`TensorCow`] with [`DataCow::Ref`] | not cloned (layout-compatible) |
/// | | | Owned <br> [`TensorCow`] with [`DataCow::Owned`] | cloned (layout-not-compatible) |
///
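/// A minimal sketch of the table above (assuming a contiguous CPU tensor, so the layout itself
/// never forces a clone):
///
/// ```rust
/// use rstsr::prelude::*;
/// let a = rt::arange(6);
/// // borrowed input, layout-compatible: view output, no data clone
/// assert!(!a.reshape([2, 3]).is_owned());
/// // non-owned input to `into_shape`: always an owned tensor (clones here)
/// let b = a.view().into_shape([2, 3]);
/// // owned input to `change_shape`: owned Cow output, even without cloning
/// assert!(b.change_shape([3, 2]).is_owned());
/// ```
///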
/// # Tips on common compilation errors
///
/// You may encounter an ownership problem when you try to assign a reshaped tensor like this:
///
/// ```rust,compile_fail
/// # use rstsr::prelude::*;
/// # let mut device = DeviceCpu::default();
/// # device.set_default_order(RowMajor);
/// let a = rt::arange((6, &device)).reshape([2, 3]);
/// ```
///
/// The compiler may give an error like:
///
/// ```text
/// 704 |     let a = rt::arange((6, &device)).reshape([2, 3]);
///     |             ^^^^^^^^^^^^^^^^^^^^^^^^                 - temporary value is freed at the end of this statement
///     |             |
///     |             creates a temporary value which is freed while still in use
/// 705 |     println!("a: {:?}", a);
///     |                         - borrow later used here
///     |
/// help: consider using a `let` binding to create a longer lived value
///     |
/// 704 ~     let binding = rt::arange((6, &device));
/// 705 ~     let a = binding.reshape([2, 3]);
///     |
/// ```
///
/// The compiler's suggestion is correct. However, a simpler way to solve this problem is to use
/// the [`into_shape`] variant, which takes ownership:
///
/// ```rust
/// # use rstsr::prelude::*;
/// # let mut device = DeviceCpu::default();
/// # device.set_default_order(RowMajor);
/// let a = rt::arange((6, &device)).into_shape([2, 3]);
/// ```
///
/// # Notes of accordance
///
/// ## To Python Array API Standard
///
/// This function corresponds to Python Array API Standard:
/// [`reshape`](https://data-apis.org/array-api/2024.12/API_specification/generated/array_api.reshape.html).
///
/// However, please note that this function does not implement the optional keyword `copy` as in
/// the standard. The `copy` keyword in the standard specifies whether to return a copy of the
/// array data when the requested shape is not compatible with the original shape.
///
/// This function implements the `copy = None` behavior in the standard, which means that it will
/// return a view if possible, and return an owned tensor (cloning the data) if necessary.
///
/// To achieve similar functionality to the optional keyword `copy`,
///
/// - For the `copy = True` case, you are recommended to
///
///     - use [`into_shape`], which always returns an owned tensor, cloning the data if
///       necessary. But note that the necessity of cloning depends on the layout, and RSTSR may
///       still not explicitly perform cloning.
///     - use [`to_contig`], which always returns a contiguous owned tensor, cloning the data if
///       necessary. But note that this function may still not explicitly perform cloning if the
///       tensor is already contiguous.
///     - use [`to_owned`](TensorAny::to_owned) as an associated method to give an owned tensor,
///       which always performs cloning.
///
/// - For the `copy = False` case, you are recommended to
///
///     - use the utility function [`layout_reshapeable`] to check whether the layout is
///       compatible with the new shape.
///
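/// As a minimal sketch emulating `copy = False` without calling [`layout_reshapeable`] directly
/// (an illustration, not a dedicated RSTSR API), one may reshape and then verify that no clone
/// occurred:
///
/// ```rust
/// use rstsr::prelude::*;
/// let a = rt::arange(6);
/// let b = a.reshape([2, 3]);
/// // `copy = False` semantics: accept the result only if it is still a view (no clone)
/// assert!(!b.is_owned());
/// ```
///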
/// ## To NumPy
///
/// This function corresponds to NumPy:
/// [`reshape`](https://numpy.org/doc/stable/reference/generated/numpy.reshape.html).
///
/// However, please note that this function does not implement the optional keyword `order` as in
/// the NumPy version. The `order` keyword in NumPy specifies the iteration order used to read
/// elements from the tensor to be reshaped.
///
/// This function uses the device's default order to determine the layout of the reshaped tensor.
/// You can check the device's current default order by
/// [`device.default_order`](DeviceBaseAPI::default_order). Also see the elaborated examples below.
///
/// To change the device's default order, you can use
///
/// - [`device.set_default_order`](DeviceBaseAPI::set_default_order) to set the default order of a
///   device instance, and then
/// - [`change_device`](TensorDeviceChangeAPI::change_device) or
///   [`into_device`](TensorDeviceChangeAPI::into_device) or
///   [`to_device`](TensorDeviceChangeAPI::to_device) to change the tensor's device to the
///   modified device. Choose the appropriate method depending on the desired ownership semantics.
///
/// # Elaborated examples
///
/// ## Difference between [RowMajor] and [ColMajor]
///
/// A tensor can be uniquely iterated (flattened into a 1-dimensional vector) in either row-major
/// or column-major order.
///
/// **A reshape operation does not change the iterated sequence of a tensor**, by definition. In
/// other words, the following code always holds true:
///
/// ```rust
/// # use rstsr::prelude::*;
/// # let mut device = DeviceCpu::default();
/// # device.set_default_order(ColMajor);
/// # let a = rt::tensor_from_nested!([[0, 1, 2], [3, 4, 5]], &device);
/// # let b = a.reshape([3, 2]);
/// // note that the iteration order of the associated method `iter` depends on
/// // `device.default_order()`
///
/// // let b = a.reshape(... SOME SHAPE ...);
/// let a_vec = a.iter().collect::<Vec<_>>();
/// let b_vec = b.iter().collect::<Vec<_>>();
/// assert_eq!(a_vec, b_vec); // iterated sequence is the same
/// ```
///
/// For example, in row-major order, reshaping a (2, 3) matrix to (3, 2):
///
/// ```rust
/// # use rstsr::prelude::*;
/// # let mut device = DeviceCpu::default();
/// // set to row-major order
/// device.set_default_order(RowMajor);
/// // a: [[0, 1, 2], [3, 4, 5]]
/// // b: [[0, 1], [2, 3], [4, 5]]
/// // iterated sequence: [0, 1, 2, 3, 4, 5]
/// let a = rt::tensor_from_nested!([[0, 1, 2], [3, 4, 5]], &device);
/// let b = a.reshape([3, 2]);
/// let b_expected = rt::tensor_from_nested!([[0, 1], [2, 3], [4, 5]], &device);
/// assert!(rt::allclose(&b, &b_expected, None));
/// let a_vec = a.iter().cloned().collect::<Vec<_>>();
/// let b_vec = b.iter().cloned().collect::<Vec<_>>();
/// assert_eq!(a_vec, b_vec); // iterated sequence is the same
/// assert_eq!(a_vec, vec![0, 1, 2, 3, 4, 5]);
/// ```
///
/// In column-major order, reshaping the same (2, 3) matrix to (3, 2) yields a different result:
///
/// ```rust
/// # use rstsr::prelude::*;
/// # let mut device = DeviceCpu::default();
/// // set to column-major order
/// device.set_default_order(ColMajor);
/// // a: [[0, 1, 2], [3, 4, 5]]
/// // b: [[0, 4], [3, 2], [1, 5]]
/// // iterated sequence: [0, 3, 1, 4, 2, 5]
/// let a = rt::tensor_from_nested!([[0, 1, 2], [3, 4, 5]], &device);
/// let b = a.reshape([3, 2]);
/// let b_expected = rt::tensor_from_nested!([[0, 4], [3, 2], [1, 5]], &device);
/// assert!(rt::allclose(&b, &b_expected, None));
/// let a_vec = a.iter().cloned().collect::<Vec<_>>();
/// let b_vec = b.iter().cloned().collect::<Vec<_>>();
/// assert_eq!(a_vec, b_vec); // iterated sequence is the same
/// assert_eq!(a_vec, vec![0, 3, 1, 4, 2, 5]);
/// ```
///
/// ## Occasions of data cloning
///
/// The following discussion assumes that the tensor is in row-major order. A similar discussion
/// applies to column-major order.
///
/// If the tensor to be reshaped is already C-contiguous (when the device is row-major) or
/// F-contiguous (when the device is column-major), then the reshape operation can be performed
/// without any data cloning.
///
/// Otherwise, whether data cloning is necessary depends on the layout. For example, consider a
/// tensor of shape (4, 6, 9) with non-contiguous strides:
///
/// ```rust
/// # use rstsr::prelude::*;
/// # let mut device = DeviceCpu::default();
/// # device.set_default_order(RowMajor);
/// // shape: (4, 6, 9), stride: (72, 9, 1), not c-contiguous
/// // contiguous situation: (4, [6, 9]), or say the last two dimensions are contiguous
/// let a = rt::arange((288, &device)).into_shape([4, 8, 9]).into_slice((.., 0..6, ..));
/// assert_eq!(a.shape(), &[4, 6, 9]);
/// assert_eq!(a.stride(), &[72, 9, 1]);
/// assert!(!a.c_contig());
/// ```
///
/// These cases do not require data cloning (a view is returned, or [`DataCow::Ref`] internally):
///
/// ```rust
/// # use rstsr::prelude::*;
/// # let mut device = DeviceCpu::default();
/// # device.set_default_order(RowMajor);
/// # let a = rt::arange((288, &device)).into_shape([4, 8, 9]).into_slice((.., 0..6, ..));
/// // split a single dimension into multiple dimensions
/// assert!(!a.reshape([2, 2, 6, 9]).is_owned()); // (4, 6, 9) -> ([2, 2], 6, 9)
/// assert!(!a.reshape([4, 3, 2, 9]).is_owned()); // (4, 6, 9) -> (4, [3, 2], 9)
/// assert!(!a.reshape([4, 2, 3, 3, 3]).is_owned()); // (4, 6, 9) -> (4, [2, 3], [3, 3])
///
/// // merge contiguous dimensions into a single dimension
/// assert!(!a.reshape([4, 54]).is_owned()); // (4, 6, 9) -> (4, 6 * 9)
///
/// // merge contiguous dimensions and then split
/// assert!(!a.reshape([4, 3, 6, 3]).is_owned()); // (4, [6, 9]) -> (4, [3, 6, 3])
/// ```
///
/// However, the following cases require data cloning (an owned tensor is returned, or
/// [`DataCow::Owned`] internally):
///
/// ```rust
/// # use rstsr::prelude::*;
/// # let mut device = DeviceCpu::default();
/// # device.set_default_order(RowMajor);
/// # let a = rt::arange((288, &device)).into_shape([4, 8, 9]).into_slice((.., 0..6, ..));
/// assert!(a.reshape([24, 9]).is_owned()); // (4, 6, 9) -> (4 * 6, 9)
/// assert!(a.reshape(-1).is_owned()); // (4, 6, 9) -> (4 * 6 * 9)
/// assert!(a.reshape([12, 2, 9]).is_owned()); // (4, 6, 9) -> (4 * [3, 2], 9)
/// ```
///
/// # See also
///
/// ## Similar function from other crates/libraries
///
/// - Python Array API standard: [`reshape`](https://data-apis.org/array-api/2024.12/API_specification/generated/array_api.reshape.html)
/// - NumPy: [`reshape`](https://numpy.org/doc/stable/reference/generated/numpy.reshape.html)
/// - ndarray: [`to_shape`](https://docs.rs/ndarray/latest/ndarray/struct.ArrayBase.html#method.to_shape)
///
/// ## Related functions in RSTSR
///
/// - [`reshape_assume_contig`]: Reshape assuming the tensor is contiguous.
/// - [`layout_reshapeable`]: Check whether the layout is compatible with the new shape.
/// - [`to_layout`]: Return a tensor with the specified layout.
/// - [`to_contig`]: Return an owned contiguous tensor.
///
/// ## Variants of this function
///
/// - [`reshape`] / [`reshape_f`]: Taking reference and returning Cow.
/// - [`into_shape`] / [`into_shape_f`]: Taking ownership and returning owned tensor.
/// - [`change_shape`] / [`change_shape_f`]: Taking ownership and returning Cow.
/// - [`to_shape`] / [`to_shape_f`]: Alias to [`reshape`] / [`reshape_f`].
/// - Associated methods on [`TensorAny`]:
///
///     - [`TensorAny::reshape`] / [`TensorAny::reshape_f`]
///     - [`TensorAny::into_shape`] / [`TensorAny::into_shape_f`]
///     - [`TensorAny::change_shape`] / [`TensorAny::change_shape_f`]
///     - [`TensorAny::to_shape`] / [`TensorAny::to_shape_f`]
pub fn reshape<'a, I, R, T, B, D>(tensor: &'a TensorAny<R, T, B, D>, shape: I) -> TensorCow<'a, T, B, IxD>
where
    I: TryInto<AxesIndex<isize>, Error = Error>,
    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw> + DataIntoCowAPI<'a>,
    D: DimAPI,
    B: DeviceAPI<T> + DeviceRawAPI<MaybeUninit<T>> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, IxD, D>,
{
    to_shape(tensor, shape)
}

/// Reshapes the given tensor to the specified shape.
///
/// # See also
///
/// Refer to [`reshape`], [`into_shape`] and [`change_shape`] for more details and
/// examples.
impl<'a, R, T, B, D> TensorAny<R, T, B, D>
where
    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw> + DataIntoCowAPI<'a>,
    D: DimAPI,
    B: DeviceAPI<T> + DeviceRawAPI<MaybeUninit<T>> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, IxD, D>,
    T: Clone,
{
    /// Reshapes the given tensor to the specified shape.
    ///
    /// # See also
    ///
    /// Refer to [`reshape`], [`into_shape`] and [`change_shape`] for more details and
    /// examples.
    pub fn change_shape_f<I>(self, shape: I) -> Result<TensorCow<'a, T, B, IxD>>
    where
        I: TryInto<AxesIndex<isize>, Error = Error>,
    {
        change_shape_f(self, shape)
    }

    /// Reshapes the given tensor to the specified shape.
    ///
    /// # See also
    ///
    /// Refer to [`reshape`], [`into_shape`] and [`change_shape`] for more details and
    /// examples.
    pub fn change_shape<I>(self, shape: I) -> TensorCow<'a, T, B, IxD>
    where
        I: TryInto<AxesIndex<isize>, Error = Error>,
    {
        change_shape(self, shape)
    }

    /// Reshapes the given tensor to the specified shape.
    ///
    /// # See also
    ///
    /// Refer to [`reshape`], [`into_shape`] and [`change_shape`] for more details and
    /// examples.
    pub fn into_shape_f<I>(self, shape: I) -> Result<Tensor<T, B, IxD>>
    where
        I: TryInto<AxesIndex<isize>, Error = Error>,
        <B as DeviceRawAPI<T>>::Raw: Clone + 'a,
        B: OpAssignAPI<T, IxD>,
    {
        into_shape_f(self, shape)
    }

    /// Reshapes the given tensor to the specified shape.
    ///
    /// # See also
    ///
    /// Refer to [`reshape`], [`into_shape`] and [`change_shape`] for more details and
    /// examples.
    pub fn into_shape<I>(self, shape: I) -> Tensor<T, B, IxD>
    where
        I: TryInto<AxesIndex<isize>, Error = Error>,
        <B as DeviceRawAPI<T>>::Raw: Clone + 'a,
        B: OpAssignAPI<T, IxD>,
    {
        into_shape(self, shape)
    }

    /// Reshapes the given tensor to the specified shape.
    ///
    /// # See also
    ///
    /// Refer to [`reshape`], [`into_shape`] and [`change_shape`] for more details and
    /// examples.
    pub fn to_shape_f<I>(&'a self, shape: I) -> Result<TensorCow<'a, T, B, IxD>>
    where
        I: TryInto<AxesIndex<isize>, Error = Error>,
    {
        self.view().change_shape_f(shape)
    }

    /// Reshapes the given tensor to the specified shape.
    ///
    /// # See also
    ///
    /// Refer to [`reshape`], [`into_shape`] and [`change_shape`] for more details and
    /// examples.
    pub fn to_shape<I>(&'a self, shape: I) -> TensorCow<'a, T, B, IxD>
    where
        I: TryInto<AxesIndex<isize>, Error = Error>,
    {
        self.view().change_shape(shape)
    }

    /// Reshapes the given tensor to the specified shape.
    ///
    /// # See also
    ///
    /// Refer to [`reshape`], [`into_shape`] and [`change_shape`] for more details and
    /// examples.
    pub fn reshape_f<I>(&'a self, shape: I) -> Result<TensorCow<'a, T, B, IxD>>
    where
        I: TryInto<AxesIndex<isize>, Error = Error>,
    {
        self.view().change_shape_f(shape)
    }

    /// Reshapes the given tensor to the specified shape.
    ///
    /// # See also
    ///
    /// Refer to [`reshape`], [`into_shape`] and [`change_shape`] for more details and
    /// examples.
    pub fn reshape<I>(&'a self, shape: I) -> TensorCow<'a, T, B, IxD>
    where
        I: TryInto<AxesIndex<isize>, Error = Error>,
    {
        self.view().change_shape(shape)
    }
}

/* #endregion */

#[cfg(test)]
mod tests {
    #[test]
    #[rustfmt::skip]
    fn doc_reshape() {
        use rstsr::prelude::*;
        let mut device = DeviceCpu::default();
        device.set_default_order(RowMajor);

        let a = rt::arange((6, &device));
        let a_reshaped = a.reshape([2, 3]);
        let a_expected = rt::tensor_from_nested!(
            [[0, 1, 2], [3, 4, 5]],
            &device);
        assert!(rt::allclose(&a_reshaped, &a_expected, None));

        // in this case, unspecified axes length is inferred as 6 / 3 = 2
        let a_reshaped = a.reshape([3, -1]);
        let a_expected = rt::tensor_from_nested!(
            [[0, 1], [2, 3], [4, 5]],
            &device);
        assert!(rt::allclose(&a_reshaped, &a_expected, None));
    }

    #[test]
    fn doc_reshape_elaborated_diff_row_col() {
        use rstsr::prelude::*;

        let mut device = DeviceCpu::default();
        device.set_default_order(RowMajor);
        let a = rt::tensor_from_nested!([[0, 1, 2], [3, 4, 5]], &device);
        let b = a.reshape([3, 2]);
        let a_vec = a.iter().collect::<Vec<_>>();
        let b_vec = b.iter().collect::<Vec<_>>();
        assert_eq!(a_vec, b_vec); // iterated sequence is the same

        let mut device = DeviceCpu::default();
        device.set_default_order(ColMajor);
        let a = rt::tensor_from_nested!([[0, 1, 2], [3, 4, 5]], &device);
        let b = a.reshape([3, 2]);
        let a_c_vec = a.iter().collect::<Vec<_>>();
        let b_c_vec = b.iter().collect::<Vec<_>>();
        assert_eq!(a_c_vec, b_c_vec); // iterated sequence is the same
        assert_ne!(a_c_vec, a_vec); // iterated sequence is different from row-major

        // Row-major reshape
        let mut device = DeviceCpu::default();
        device.set_default_order(RowMajor);
        // a: [[0, 1, 2], [3, 4, 5]]
        // b: [[0, 1], [2, 3], [4, 5]]
        // iterated sequence: [0, 1, 2, 3, 4, 5]
        let a = rt::tensor_from_nested!([[0, 1, 2], [3, 4, 5]], &device);
        let b = a.reshape([3, 2]);
        let b_expected = rt::tensor_from_nested!([[0, 1], [2, 3], [4, 5]], &device);
        assert!(rt::allclose(&b, &b_expected, None));
        let a_vec = a.iter().cloned().collect::<Vec<_>>();
        let b_vec = b.iter().cloned().collect::<Vec<_>>();
        assert_eq!(a_vec, b_vec); // iterated sequence is the same
        assert_eq!(a_vec, vec![0, 1, 2, 3, 4, 5]);

        // Column-major reshape
        let mut device = DeviceCpu::default();
        device.set_default_order(ColMajor);
        // a: [[0, 1, 2], [3, 4, 5]]
        // b: [[0, 4], [3, 2], [1, 5]]
        // iterated sequence: [0, 3, 1, 4, 2, 5]
        let a = rt::tensor_from_nested!([[0, 1, 2], [3, 4, 5]], &device);
        let b = a.reshape([3, 2]);
        let b_expected = rt::tensor_from_nested!([[0, 4], [3, 2], [1, 5]], &device);
        assert!(rt::allclose(&b, &b_expected, None));
        let a_vec = a.iter().cloned().collect::<Vec<_>>();
        let b_vec = b.iter().cloned().collect::<Vec<_>>();
        assert_eq!(a_vec, b_vec); // iterated sequence is the same
        assert_eq!(a_vec, vec![0, 3, 1, 4, 2, 5]);
    }

    #[test]
    fn doc_reshape_elaborated_clone_occasion() {
        use rstsr::prelude::*;

        let mut device = DeviceCpu::default();
        device.set_default_order(RowMajor);

        // some strided tensor
        // shape: (4, 6, 9), stride: (72, 9, 1), not c-contiguous
        // contiguous situation: (4, [6, 9]), or say the last two dimensions are contiguous
        let a = rt::arange((288, &device)).into_shape([4, 8, 9]).into_slice((.., 0..6, ..));
        assert_eq!(a.shape(), &[4, 6, 9]);
        assert_eq!(a.stride(), &[72, 9, 1]);
        assert!(!a.c_contig());

        // reshape that does not require clone (outputs tensor view)

        // split a single dimension into multiple dimensions
        assert!(!a.reshape([2, 2, 6, 9]).is_owned()); // (4, 6, 9) -> ([2, 2], 6, 9)
        assert!(!a.reshape([4, 3, 2, 9]).is_owned()); // (4, 6, 9) -> (4, [3, 2], 9)
        assert!(!a.reshape([4, 2, 3, 3, 3]).is_owned()); // (4, 6, 9) -> (4, [2, 3], [3, 3])

        // merge contiguous dimensions into a single dimension
        assert!(!a.reshape([4, 54]).is_owned()); // (4, 6, 9) -> (4, 6 * 9)

        // merge contiguous dimensions and then split
        assert!(!a.reshape([4, 3, 6, 3]).is_owned()); // (4, [6, 9]) -> (4, [3, 6, 3])

        // reshape that requires clone (outputs owned tensor)

        // merge non-contiguous dimensions
        assert!(a.reshape([24, 9]).is_owned()); // (4, 6, 9) -> (4 * 6, 9)
        assert!(a.reshape(-1).is_owned()); // (4, 6, 9) -> (4 * 6 * 9)
        assert!(a.reshape([12, 2, 9]).is_owned()); // (4, 6, 9) -> (4 * [3, 2], 9)
    }

    #[test]
    fn doc_into_shape() {
        use rstsr::prelude::*;
        let mut device = DeviceCpu::default();
        device.set_default_order(RowMajor);

        let a = rt::arange((6, &device)).into_shape([2, 3]);
        println!("a: {:?}", a);

        // shape: (4, 6, 9), stride: (-54, 9, 1), not c-contiguous
        // contiguous situation: (4, [6, 9]); the first dimension is reversed
        let a = rt::arange((216, &device)).into_shape([4, 6, 9]).into_flip(0);
        let a_ptr = a.raw().as_ptr();
        let b = a.into_shape([4, 54]);
        let b_ptr = b.raw().as_ptr();
        assert_eq!(a_ptr, b_ptr); // contiguous dims merged, no data clone happened

        // shape: (4, 6, 9), stride: (-54, 9, 1), not c-contiguous
        // contiguous situation: (4, [6, 9]); the first dimension is reversed
        let a = rt::arange((216, &device)).into_shape([4, 6, 9]).into_flip(0);
        let a_ptr = a.raw().as_ptr();
        let b = a.into_shape([24, 9]);
        let b_ptr = b.raw().as_ptr();
        assert_ne!(a_ptr, b_ptr); // layout not compatible, data clone happened

        // shape: (4, 6, 9), stride: (72, 9, 1), not c-contiguous
        // contiguous situation: (4, [6, 9]), or say the last two dimensions are contiguous
        let a = rt::arange((288, &device)).into_shape([4, 8, 9]).into_slice((.., 0..6, ..));
        let a_ptr = a.raw().as_ptr();
        let b = a.into_shape([4, 54]);
        let b_ptr = b.raw().as_ptr();
        assert_ne!(a_ptr, b_ptr); // layout-compatible, but input tensor is not compact (216 < 288)
    }
}