Skip to main content

rstsr_core/tensor/manuplication/
to_layout.rs

1use crate::prelude_dev::*;
2
3/* #region to_layout */
4
5/// Convert tensor to a specified layout.
6///
7/// See also [`to_layout`].
8pub fn change_layout_f<'a, R, T, B, D, D2>(
9    tensor: TensorAny<R, T, B, D>,
10    layout: Layout<D2>,
11) -> Result<TensorCow<'a, T, B, D2>>
12where
13    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw> + DataIntoCowAPI<'a>,
14    D: DimAPI,
15    D2: DimAPI,
16    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, D2, D>,
17{
18    let shape = layout.shape();
19    rstsr_assert_eq!(tensor.size(), shape.shape_size(), InvalidLayout)?;
20    let same_layout = tensor.layout().to_dim::<IxD>()? == layout.to_dim::<IxD>()?;
21    if same_layout {
22        // no data cloned
23        let (storage, _) = tensor.into_raw_parts();
24        let tensor = unsafe { TensorBase::new_unchecked(storage, layout) };
25        return Ok(tensor.into_cow());
26    } else {
27        // layout changed, or not c and f contiguous with same layout
28        // clone data by assign
29        let (storage_old, layout_old) = tensor.into_raw_parts();
30        let device = storage_old.device();
31        let (_, idx_max) = layout.bounds_index()?;
32        let mut storage_new = device.uninit_impl(idx_max)?;
33        device.assign_arbitary_uninit(storage_new.raw_mut(), &layout, storage_old.raw(), &layout_old)?;
34        let storage_new = unsafe { B::assume_init_impl(storage_new)? };
35        let tensor = unsafe { TensorBase::new_unchecked(storage_new, layout) };
36        return Ok(tensor.into_cow());
37    }
38}
39
40/// Convert tensor to a specified layout.
41///
42/// This function takes a reference to a tensor and a target layout, returning a [`TensorCow`]
43/// that is either a view (if the layout matches or both are contiguous) or a newly allocated
44/// copy with the requested layout.
45///
46/// The layout can differ from the original in shape, strides, or even dimensionality,
47/// as long as the total number of elements remains the same.
48///
49/// # Arguments
50///
51/// - `tensor`: A reference to the input tensor.
52/// - `layout`: The target [`Layout`] for the output tensor.
53///
54/// # Returns
55///
56/// A [`TensorCow`] containing either a view (if no copy needed) or an owned tensor with the
57/// specified layout.
58///
59/// # Errors
60///
61/// Returns an error if the layout size doesn't match the tensor size.
62/// Use [`to_layout_f`] for the fallible version.
63///
64/// # Examples
65///
66/// ```rust
67/// # use rstsr::prelude::*;
68/// # let mut device = DeviceCpu::default();
69/// # device.set_default_order(RowMajor);
70/// // Convert tensor to a different layout
71/// let a = rt::arange((12, &device)).into_shape([3, 4]);
72/// println!("a layout: {:?}", a.layout());
73/// // 2-Dim (dyn), contiguous: Cc
74/// // shape: [3, 4], stride: [4, 1], offset: 0
75///
76/// // Convert to F-contiguous layout
77/// let layout_f = [3, 4].f();
78/// let b = a.to_layout(layout_f);
79/// println!("b layout: {:?}", b.layout());
80/// // 2-Dim (dyn), contiguous: Fc
81/// // shape: [3, 4], stride: [1, 3], offset: 0
82/// assert!(b.f_contig());
83/// ```
84///
85/// ```rust
86/// # use rstsr::prelude::*;
87/// # let mut device = DeviceCpu::default();
88/// # device.set_default_order(RowMajor);
89/// // Using to_layout to reshape tensor
90/// let a = rt::arange((12, &device)).into_shape([3, 4]);
91///
92/// // Flatten to 1D
93/// let layout_1d = [12].c();
94/// let b = a.to_layout(layout_1d);
95/// assert_eq!(b.shape(), &[12]);
96///
97/// // Reshape to different 2D
98/// let layout_2d = [2, 6].c();
99/// let c = b.to_layout(layout_2d);
100/// assert_eq!(c.shape(), &[2, 6]);
101/// ```
102///
103/// # See also
104///
105/// ## Similar functions in RSTSR
106///
107/// - [`reshape`]: Change the shape of a tensor (inputs shape instead of layout).
108/// - [`to_contig`]: Convert tensor to C or F contiguous layout.
109/// - [`transpose`]: Permute dimensions of a tensor (returns a view).
110///
111/// ## Variants of this function
112///
113/// - [`to_layout`] / [`to_layout_f`]: Non-consuming version that takes a reference and returns a
114///   view or owned tensor.
115/// - [`into_layout`] / [`into_layout_f`]: Consuming version that returns an owned tensor directly.
116/// - [`change_layout`] / [`change_layout_f`]: Consuming version that returns a view or owned
117///   tensor.
118/// - Associated methods on [`TensorAny`]:
119///
120///   - [`TensorAny::to_layout`] / [`TensorAny::to_layout_f`]
121///   - [`TensorAny::into_layout`] / [`TensorAny::into_layout_f`]
122///   - [`TensorAny::change_layout`] / [`TensorAny::change_layout_f`]
123pub fn to_layout<R, T, D, B, D2>(tensor: &TensorAny<R, T, B, D>, layout: Layout<D2>) -> TensorCow<'_, T, B, D2>
124where
125    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
126    D: DimAPI,
127    D2: DimAPI,
128    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, D2, D>,
129{
130    change_layout_f(tensor.view(), layout).rstsr_unwrap()
131}
132
133/// Convert tensor to a specified layout.
134///
135/// See also [`to_layout`].
136pub fn to_layout_f<R, T, D, B, D2>(
137    tensor: &TensorAny<R, T, B, D>,
138    layout: Layout<D2>,
139) -> Result<TensorCow<'_, T, B, D2>>
140where
141    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
142    D: DimAPI,
143    D2: DimAPI,
144    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, D2, D>,
145{
146    change_layout_f(tensor.view(), layout)
147}
148
149/// Convert tensor to a specified layout.
150///
151/// See also [`to_layout`].
152pub fn into_layout_f<'a, R, T, B, D, D2>(tensor: TensorAny<R, T, B, D>, layout: Layout<D2>) -> Result<Tensor<T, B, D2>>
153where
154    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw> + DataIntoCowAPI<'a>,
155    D: DimAPI,
156    D2: DimAPI,
157    T: Clone,
158    B: DeviceAPI<T>
159        + DeviceRawAPI<MaybeUninit<T>>
160        + DeviceCreationAnyAPI<T>
161        + OpAssignArbitaryAPI<T, D2, D>
162        + OpAssignAPI<T, D2>,
163    <B as DeviceRawAPI<T>>::Raw: Clone + 'a,
164{
165    change_layout_f(tensor, layout).map(|v| v.into_owned())
166}
167
168/// Convert tensor to a specified layout.
169///
170/// See also [`to_layout`].
171pub fn into_layout<'a, R, T, B, D, D2>(tensor: TensorAny<R, T, B, D>, layout: Layout<D2>) -> Tensor<T, B, D2>
172where
173    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw> + DataIntoCowAPI<'a>,
174    D: DimAPI,
175    D2: DimAPI,
176    T: Clone,
177    B: DeviceAPI<T>
178        + DeviceRawAPI<MaybeUninit<T>>
179        + DeviceCreationAnyAPI<T>
180        + OpAssignArbitaryAPI<T, D2, D>
181        + OpAssignAPI<T, D2>,
182    <B as DeviceRawAPI<T>>::Raw: Clone + 'a,
183{
184    into_layout_f(tensor, layout).rstsr_unwrap()
185}
186
187/// Convert tensor to a specified layout.
188///
189/// See also [`to_layout`].
190pub fn change_layout<'a, R, T, B, D, D2>(tensor: TensorAny<R, T, B, D>, layout: Layout<D2>) -> TensorCow<'a, T, B, D2>
191where
192    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw> + DataIntoCowAPI<'a>,
193    D: DimAPI,
194    D2: DimAPI,
195    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, D2, D>,
196{
197    change_layout_f(tensor, layout).rstsr_unwrap()
198}
199
200impl<'a, R, T, B, D> TensorAny<R, T, B, D>
201where
202    R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
203    D: DimAPI,
204    T: Clone,
205    B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
206{
207    /// Convert tensor to a specified layout.
208    ///
209    /// See also [`to_layout`].
210    pub fn to_layout<D2>(&self, layout: Layout<D2>) -> TensorCow<'_, T, B, D2>
211    where
212        D2: DimAPI,
213        B: OpAssignArbitaryAPI<T, D2, D>,
214    {
215        to_layout(self, layout)
216    }
217
218    /// Convert tensor to a specified layout.
219    ///
220    /// See also [`to_layout`].
221    pub fn to_layout_f<D2>(&self, layout: Layout<D2>) -> Result<TensorCow<'_, T, B, D2>>
222    where
223        D2: DimAPI,
224        B: OpAssignArbitaryAPI<T, D2, D>,
225    {
226        to_layout_f(self, layout)
227    }
228
229    /// Convert tensor to a specified layout.
230    ///
231    /// See also [`to_layout`].
232    pub fn into_layout_f<D2>(self, layout: Layout<D2>) -> Result<Tensor<T, B, D2>>
233    where
234        D2: DimAPI,
235        B: DeviceRawAPI<MaybeUninit<T>> + OpAssignArbitaryAPI<T, D2, D> + OpAssignAPI<T, D2>,
236        <B as DeviceRawAPI<T>>::Raw: Clone + 'a,
237    {
238        into_layout_f(self, layout)
239    }
240
241    /// Convert tensor to a specified layout.
242    ///
243    /// See also [`to_layout`].
244    pub fn into_layout<D2>(self, layout: Layout<D2>) -> Tensor<T, B, D2>
245    where
246        D2: DimAPI,
247        B: DeviceRawAPI<MaybeUninit<T>> + OpAssignArbitaryAPI<T, D2, D> + OpAssignAPI<T, D2>,
248        <B as DeviceRawAPI<T>>::Raw: Clone + 'a,
249    {
250        into_layout(self, layout)
251    }
252
253    /// Convert tensor to a specified layout.
254    ///
255    /// See also [`to_layout`].
256    pub fn change_layout_f<D2>(self, layout: Layout<D2>) -> Result<TensorCow<'a, T, B, D2>>
257    where
258        D2: DimAPI,
259        B: OpAssignArbitaryAPI<T, D2, D>,
260    {
261        change_layout_f(self, layout)
262    }
263
264    /// Convert tensor to a specified layout.
265    ///
266    /// See also [`to_layout`].
267    pub fn change_layout<D2>(self, layout: Layout<D2>) -> TensorCow<'a, T, B, D2>
268    where
269        D2: DimAPI,
270        B: OpAssignArbitaryAPI<T, D2, D>,
271    {
272        change_layout(self, layout)
273    }
274}
275
276/* #endregion */