hpt_common/layout/
layout_utils.rs

use std::panic::Location;

use crate::{
    axis::axis::{process_axes, Axis},
    error::{base::TensorError, shape::ShapeError},
    shape::{
        shape::Shape,
        shape_utils::{is_reshape_possible, predict_broadcast_shape, try_pad_shape},
    },
    strides::{strides::Strides, strides_utils::shape_to_strides},
};

use super::layout::Layout;
impl Layout {
    /// # Internal Function
    ///
    /// a function mainly used to check whether a reshape is possible
    ///
    /// in most cases, a tensor can be reshaped if the dimensions being reshaped are `contiguous`
    ///
    /// # Case
    ///
    /// if the shape is `[2, 3, 4, 5]` and the strides are `[60, 20, 5, 1]`, the layout is `contiguous`
    ///
    /// if we permute the shape to `[2, 3, 5, 4]` and the strides to `[60, 20, 1, 5]`, the layout as a whole is no longer `contiguous`, but the first two dimensions still are
    ///
    /// if the reshape only touches the first two dimensions, like `[2, 3, 5, 4]` to `[6, 5, 4]` or `[2, 3, 5, 4]` to `[3, 2, 5, 4]`, it is possible
    ///
    /// if the reshape touches the last two dimensions, like `[2, 3, 5, 4]` to `[2, 3, 20]`, it is not possible
    ///
    /// # Arguments
    ///
    /// * `shape` - the new shape to reshape to
    ///
    /// # Returns
    ///
    /// * `Option<Strides>` - if the reshape is possible, return the new strides, otherwise return `None`
    ///
    /// # Examples
    ///
    /// ```
    /// use hpt_common::layout::Layout;
    /// use hpt_common::shape::Shape;
    /// use hpt_common::strides::strides::Strides;
    ///
    /// let shape = Shape::from(vec![2, 3, 4]);
    /// let strides = Strides::from(vec![12, 4, 1]);
    ///
    /// let layout = Layout::new(shape.clone(), strides.clone());
    ///
    /// let new_shape = Shape::from(vec![3, 2, 4]);
    /// let new_strides = layout.is_reshape_possible(&new_shape).unwrap();
    ///
    /// assert_eq!(new_strides, Strides::from(vec![8, 4, 1]));
    /// ```
    pub fn is_reshape_possible(&self, shape: &[i64]) -> Option<Strides> {
        if self.size() != shape.iter().product::<i64>() {
            return None;
        }
        is_reshape_possible(&self.shape, &self.strides, shape)
    }

    /// # Internal Function
    ///
    /// a function used to calculate the broadcast layout for a given target shape
    ///
    /// # Arguments
    ///
    /// * `target_shape` - the target shape
    ///
    /// # Returns
    ///
    /// * `Result<Layout>` - the broadcast layout
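    ///
    /// # Examples
    ///
    /// a minimal sketch of the expected behavior (import paths and the `Layout::new`
    /// constructor are assumed to match the example on `is_reshape_possible`):
    ///
    /// ```
    /// use hpt_common::layout::Layout;
    /// use hpt_common::shape::Shape;
    /// use hpt_common::strides::strides::Strides;
    ///
    /// // broadcast a 1-D layout of shape [3] to [2, 3]
    /// let layout = Layout::new(Shape::from(vec![3]), Strides::from(vec![1]));
    /// let broadcasted = layout.to_broadcast_layout(&[2, 3]).unwrap();
    /// // the new leading dimension repeats the same data, so its stride is 0
    /// assert_eq!(broadcasted.strides().to_vec(), vec![0, 1]);
    /// ```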
    pub fn to_broadcast_layout(&self, target_shape: &[i64]) -> Result<Layout, TensorError> {
        let padded_shape = try_pad_shape(&self.shape, target_shape.len());

        // try_pad_shape can also be used on strides
        let padded_strides = try_pad_shape(&self.strides, target_shape.len());
        let mut new_strides = vec![0; target_shape.len()];

        for i in 0..target_shape.len() {
            if padded_shape[i] == target_shape[i] {
                new_strides[i] = padded_strides[i];
            } else {
                new_strides[i] = 0;
            }
        }

        Ok(Layout {
            shape: target_shape.into(),
            strides: new_strides.into(),
        })
    }

    /// # Internal Function
    ///
    /// a function mainly used for expanding the strides of a tensor
    ///
    /// this function simply sets the stride of every dimension being expanded to `0`
    ///
    /// # Arguments
    ///
    /// * `expand_shape` - the new shape to expand to
    ///
    /// # Returns
    ///
    /// * `Result<Strides>` - the new strides after expanding
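    ///
    /// # Examples
    ///
    /// a minimal sketch (import paths and constructors assumed as in the example
    /// on `is_reshape_possible`):
    ///
    /// ```
    /// use hpt_common::layout::Layout;
    /// use hpt_common::shape::Shape;
    /// use hpt_common::strides::strides::Strides;
    ///
    /// // expand the middle size-1 dimension of [2, 1, 3] to 4
    /// let layout = Layout::new(Shape::from(vec![2, 1, 3]), Strides::from(vec![3, 3, 1]));
    /// let strides = layout.expand_strides(&[2, 4, 3]).unwrap();
    /// // the expanded dimension repeats its data, so its stride becomes 0
    /// assert_eq!(strides.to_vec(), vec![3, 0, 1]);
    /// ```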
    pub fn expand_strides(&self, expand_shape: &[i64]) -> Result<Strides, TensorError> {
        let mut res_strides = vec![0; expand_shape.len()];
        // walk from the innermost dimension outwards; any extra leading
        // dimensions of `expand_shape` keep the initial stride of 0
        for (((idx, new_dim), old_dim), old_stride) in expand_shape
            .iter()
            .enumerate()
            .rev()
            .zip(self.shape.iter().rev())
            .zip(self.strides.iter().rev())
        {
            if new_dim != old_dim && old_dim == &1 {
                res_strides[idx] = 0;
            } else if new_dim != old_dim && old_dim != &1 {
                return Err(ShapeError::ExpandError {
                    old_dim: *old_dim,
                    location: Location::caller(),
                }
                .into());
            } else {
                res_strides[idx] = *old_stride;
            }
        }
        Ok(res_strides.into())
    }

    /// # Internal Function
    ///
    /// a function mainly used for calculating the real size of a tensor
    ///
    /// pretty useful when the tensor is a view of another tensor
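    ///
    /// # Examples
    ///
    /// a minimal sketch (import paths and constructors assumed as in the example
    /// on `is_reshape_possible`):
    ///
    /// ```
    /// use hpt_common::layout::Layout;
    /// use hpt_common::shape::Shape;
    /// use hpt_common::strides::strides::Strides;
    ///
    /// // a [2, 2] view taken out of a contiguous [2, 4] tensor
    /// let view = Layout::new(Shape::from(vec![2, 2]), Strides::from(vec![4, 1]));
    /// // size() reports 4 elements, but the view spans 2 * 4 = 8 slots of the buffer
    /// assert_eq!(view.real_size(), 8);
    /// ```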
    pub fn real_size(&self) -> usize {
        assert_eq!(self.shape.len(), self.strides.len());
        // the dimension with the largest stride reaches furthest into the
        // underlying buffer, so it determines the real span
        let mut max_stride = 0;
        let mut max_idx = 0;
        for (idx, stride) in self.strides.iter().enumerate() {
            if *stride > max_stride {
                max_stride = *stride;
                max_idx = idx;
            }
        }
        (self.shape[max_idx] * max_stride) as usize
    }

    /// # Internal Function
    ///
    /// a function used to calculate the permuted layout
    ///
    /// # Arguments
    ///
    /// * `axes` - the new order of the dimensions
    ///
    /// # Returns
    ///
    /// * `Result<Layout>` - the new layout after permutation
    ///
    /// # Errors
    ///
    /// if the length of `axes` is not equal to the layout's ndim
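    ///
    /// # Examples
    ///
    /// a minimal sketch (import paths as in the example on `is_reshape_possible`;
    /// a `Vec<i64>` to `Axis` conversion is assumed):
    ///
    /// ```
    /// use hpt_common::layout::Layout;
    /// use hpt_common::shape::Shape;
    /// use hpt_common::strides::strides::Strides;
    ///
    /// let layout = Layout::new(Shape::from(vec![2, 3, 4]), Strides::from(vec![12, 4, 1]));
    /// // move the last axis to the front
    /// let permuted = layout.permute(vec![2, 0, 1]).unwrap();
    /// assert_eq!(permuted.shape().to_vec(), vec![4, 2, 3]);
    /// assert_eq!(permuted.strides().to_vec(), vec![1, 12, 4]);
    /// ```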
    #[track_caller]
    pub fn permute<A: Into<Axis>>(&self, axes: A) -> Result<Layout, TensorError> {
        let axes = process_axes(axes, self.shape.len())?;
        ShapeError::check_dim(axes.len(), self.shape.len())?;
        let mut new_shape = self.shape().to_vec();
        let mut new_strides = self.strides().to_vec();
        // since `axes` is a permutation of 0..ndim, this assigns
        // new_shape[j] = shape[axes[j]] for every j
        for i in axes.iter() {
            new_shape[*i] = self.shape()[axes[*i]];
            new_strides[*i] = self.strides()[axes[*i]];
        }
        Ok(Layout {
            shape: new_shape.into(),
            strides: new_strides.into(),
        })
    }

    /// # Internal Function
    ///
    /// a function used to calculate the inverse permuted layout
    ///
    /// # Arguments
    ///
    /// * `axes` - the new order of the dimensions
    ///
    /// # Returns
    ///
    /// * `Result<Layout>` - the new layout after inverse permutation
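    ///
    /// # Examples
    ///
    /// a minimal sketch showing that `permute_inv` undoes `permute` with the same
    /// axes (import paths and the `Axis` conversion assumed as above):
    ///
    /// ```
    /// use hpt_common::layout::Layout;
    /// use hpt_common::shape::Shape;
    /// use hpt_common::strides::strides::Strides;
    ///
    /// let layout = Layout::new(Shape::from(vec![2, 3, 4]), Strides::from(vec![12, 4, 1]));
    /// let permuted = layout.permute(vec![2, 0, 1]).unwrap();
    /// let restored = permuted.permute_inv(vec![2, 0, 1]).unwrap();
    /// assert_eq!(restored.shape().to_vec(), vec![2, 3, 4]);
    /// assert_eq!(restored.strides().to_vec(), vec![12, 4, 1]);
    /// ```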
    pub fn permute_inv<A: Into<Axis>>(&self, axes: A) -> Result<Layout, TensorError> {
        let axes = process_axes(axes, self.shape.len())?;
        ShapeError::check_dim(axes.len(), self.shape.len())?;
        let mut new_shape = self.shape().to_vec();
        let mut new_strides = self.strides().to_vec();
        // scatter: new_shape[axes[j]] = shape[j], undoing `permute(axes)`
        for i in axes.iter() {
            new_shape[axes[*i]] = self.shape()[*i];
            new_strides[axes[*i]] = self.strides()[*i];
        }
        Ok(Layout {
            shape: new_shape.into(),
            strides: new_strides.into(),
        })
    }

    /// # Internal Function
    ///
    /// perform an inplace reshape on the layout
    ///
    /// # Arguments
    ///
    /// * `shape` - the new shape to reshape to
    ///
    /// # Returns
    ///
    /// * `Result<Layout>` - the new layout after the reshape
    ///
    /// # Errors
    ///
    /// if the reshape is not possible
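    ///
    /// # Examples
    ///
    /// a minimal sketch (import paths and constructors assumed as in the example
    /// on `is_reshape_possible`):
    ///
    /// ```
    /// use hpt_common::layout::Layout;
    /// use hpt_common::shape::Shape;
    /// use hpt_common::strides::strides::Strides;
    ///
    /// let layout = Layout::new(Shape::from(vec![2, 3, 4]), Strides::from(vec![12, 4, 1]));
    /// // a contiguous layout can always be reshaped in place
    /// let reshaped = layout.inplace_reshape(&Shape::from(vec![6, 4])).unwrap();
    /// assert_eq!(reshaped.strides().to_vec(), vec![4, 1]);
    /// ```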
    #[track_caller]
    pub fn inplace_reshape(&self, shape: &Shape) -> Result<Layout, TensorError> {
        if let Some(new_strides) = self.is_reshape_possible(shape) {
            Ok(Layout {
                shape: shape.clone(),
                strides: new_strides,
            })
        } else {
            Err(ShapeError::InplaceReshapeError {
                message: "Inplace reshape is not possible".to_string(),
                location: Location::caller(),
            }
            .into())
        }
    }

    /// # Internal Function
    ///
    /// broadcast the layout with another layout
    ///
    /// # Arguments
    ///
    /// * `other` - the other layout to broadcast with
    ///
    /// # Returns
    ///
    /// * `Result<Layout>` - the new layout after broadcasting
    ///
    /// # Errors
    ///
    /// if the broadcast is not possible
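    ///
    /// # Examples
    ///
    /// a minimal sketch (import paths and constructors assumed as in the example
    /// on `is_reshape_possible`):
    ///
    /// ```
    /// use hpt_common::layout::Layout;
    /// use hpt_common::shape::Shape;
    /// use hpt_common::strides::strides::Strides;
    ///
    /// let a = Layout::new(Shape::from(vec![2, 1, 4]), Strides::from(vec![4, 4, 1]));
    /// let b = Layout::new(Shape::from(vec![3, 1]), Strides::from(vec![1, 1]));
    /// // [2, 1, 4] and [3, 1] broadcast to [2, 3, 4], with contiguous strides
    /// let c = a.broadcast(&b).unwrap();
    /// assert_eq!(c.shape().to_vec(), vec![2, 3, 4]);
    /// assert_eq!(c.strides().to_vec(), vec![12, 4, 1]);
    /// ```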
    #[track_caller]
    pub fn broadcast(&self, other: &Layout) -> Result<Layout, TensorError> {
        let shape = predict_broadcast_shape(&self.shape, &other.shape)?;
        let strides = shape_to_strides(&shape);
        Ok(Layout { shape, strides })
    }

    /// # Internal Function
    ///
    /// reduce the layout along the given axes
    ///
    /// this is mainly used for reducing the dimensions of a tensor
    ///
    /// # Arguments
    ///
    /// * `axes` - the axes to be reduced
    ///
    /// * `keep_dims` - whether to keep the reduced dimensions (with size 1)
    ///
    /// # Returns
    ///
    /// * `Result<Layout>` - the new layout after the reduction
    ///
    /// # Errors
    ///
    /// if `axes` contains the same axis more than once
    ///
    /// if `axes` contains an axis out of range
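    ///
    /// # Examples
    ///
    /// a minimal sketch (import paths and the `Vec<i64>` to `Axis` conversion
    /// assumed as above):
    ///
    /// ```
    /// use hpt_common::layout::Layout;
    /// use hpt_common::shape::Shape;
    /// use hpt_common::strides::strides::Strides;
    ///
    /// let layout = Layout::new(Shape::from(vec![2, 3, 4]), Strides::from(vec![12, 4, 1]));
    /// // reduce along axis 1, dropping the reduced dimension
    /// assert_eq!(layout.reduce(vec![1], false).unwrap().shape().to_vec(), vec![2, 4]);
    /// // reduce along axis 1, keeping it with size 1
    /// assert_eq!(layout.reduce(vec![1], true).unwrap().shape().to_vec(), vec![2, 1, 4]);
    /// ```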
    pub fn reduce<A: Into<Axis>>(&self, axes: A, keep_dims: bool) -> Result<Layout, TensorError> {
        let a: Axis = axes.into();
        let axis = process_axes(a, self.shape.len())?;
        let new_shape = if keep_dims {
            let mut vec = Vec::with_capacity(self.shape.len());
            for i in 0..self.shape.len() {
                if axis.contains(&i) {
                    vec.push(1);
                } else {
                    vec.push(self.shape[i]);
                }
            }
            vec
        } else {
            let mut vec = Vec::with_capacity(self.shape.len() - axis.len());
            for i in 0..self.shape.len() {
                if !axis.contains(&i) {
                    vec.push(self.shape[i]);
                }
            }
            vec
        };
        if !new_shape.is_empty() {
            let new_strides = shape_to_strides(&new_shape);
            Ok(Layout {
                shape: new_shape.into(),
                strides: new_strides,
            })
        } else {
            // reducing over every axis yields a scalar, represented as shape [1]
            Ok(Layout {
                shape: vec![1].into(),
                strides: vec![1].into(),
            })
        }
    }

    /// simply return the product of the shape
    ///
    /// # Safety
    ///
    /// when the layout is a view of another layout, this logical size can differ from the memory the view actually spans; use `real_size` in that case
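    ///
    /// # Examples
    ///
    /// a minimal sketch (import paths and constructors assumed as in the example
    /// on `is_reshape_possible`):
    ///
    /// ```
    /// use hpt_common::layout::Layout;
    /// use hpt_common::shape::Shape;
    /// use hpt_common::strides::strides::Strides;
    ///
    /// let layout = Layout::new(Shape::from(vec![2, 3, 4]), Strides::from(vec![12, 4, 1]));
    /// assert_eq!(layout.size(), 24); // 2 * 3 * 4
    /// ```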
    #[inline(always)]
    pub fn size(&self) -> i64 {
        self.shape.iter().product::<i64>()
    }

    /// # Internal Function
    ///
    /// check if the layout is contiguous
    ///
    /// # Returns
    ///
    /// * `bool` - whether the layout is contiguous
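    ///
    /// # Examples
    ///
    /// a minimal sketch (import paths and constructors assumed as in the example
    /// on `is_reshape_possible`):
    ///
    /// ```
    /// use hpt_common::layout::Layout;
    /// use hpt_common::shape::Shape;
    /// use hpt_common::strides::strides::Strides;
    ///
    /// let contiguous = Layout::new(Shape::from(vec![2, 3, 4]), Strides::from(vec![12, 4, 1]));
    /// assert!(contiguous.is_contiguous());
    ///
    /// // a permuted layout is generally not contiguous
    /// let permuted = Layout::new(Shape::from(vec![4, 2, 3]), Strides::from(vec![1, 12, 4]));
    /// assert!(!permuted.is_contiguous());
    /// ```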
    pub fn is_contiguous(&self) -> bool {
        let mut expected_stride = 1;
        for (&dim_size, &stride) in self.shape.iter().rev().zip(self.strides.iter().rev()) {
            if dim_size == 0 {
                continue;
            }
            if stride != expected_stride {
                return false;
            }
            expected_stride *= dim_size;
        }
        true
    }

    /// # Internal Function
    ///
    /// coalesce the dimensions of a layout, grouping adjacent dimensions that are contiguous in memory
    ///
    /// # Returns
    ///
    /// * `Vec<Vec<usize>>` - the groups of dimension indices that can each be treated as a single dimension
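    ///
    /// # Examples
    ///
    /// a minimal sketch (import paths and constructors assumed as in the example
    /// on `is_reshape_possible`):
    ///
    /// ```
    /// use hpt_common::layout::Layout;
    /// use hpt_common::shape::Shape;
    /// use hpt_common::strides::strides::Strides;
    ///
    /// // fully contiguous: all dimensions collapse into a single group
    /// let contiguous = Layout::new(Shape::from(vec![2, 3, 4]), Strides::from(vec![12, 4, 1]));
    /// assert_eq!(contiguous.coalesce_dims(), vec![vec![0, 1, 2]]);
    ///
    /// // permuted [2, 3, 5, 4] with strides [60, 20, 1, 5]: only the first
    /// // two dimensions are still jointly contiguous
    /// let permuted = Layout::new(Shape::from(vec![2, 3, 5, 4]), Strides::from(vec![60, 20, 1, 5]));
    /// assert_eq!(permuted.coalesce_dims(), vec![vec![0, 1], vec![2], vec![3]]);
    /// ```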
    pub fn coalesce_dims(&self) -> Vec<Vec<usize>> {
        let shape = &self.shape;
        let strides = &self.strides;
        // start from the innermost dimension and walk outwards
        let mut groups = vec![vec![shape.len() - 1]];
        let mut current_stride = strides[shape.len() - 1];
        let mut current_size = shape[shape.len() - 1];

        for i in (0..shape.len() - 1).rev() {
            let expected_stride = current_stride * current_size;

            // a dimension joins the current group when its stride equals the
            // span of the group so far, i.e. the two are jointly contiguous
            if strides[i] == expected_stride {
                groups.last_mut().unwrap().push(i);
            } else {
                groups.push(vec![i]);
            }

            current_stride = strides[i];
            current_size = shape[i];
        }

        // groups were collected innermost-first; restore outermost-first order
        for group in groups.iter_mut() {
            group.reverse();
        }
        groups.reverse();

        groups
    }
}