hpt_common/layout/layout_utils.rs
use std::panic::Location;

use crate::{
    axis::axis::{process_axes, Axis},
    error::{base::TensorError, shape::ShapeError},
    shape::{
        shape::Shape,
        shape_utils::{is_reshape_possible, predict_broadcast_shape, try_pad_shape},
    },
    strides::{strides::Strides, strides_utils::shape_to_strides},
};

use super::layout::Layout;

impl Layout {
    /// # Internal Function
    ///
    /// a function mainly used for checking whether a reshape is possible
    ///
    /// in most cases, a tensor can be reshaped if the dimensions being reshaped are `contiguous`
    ///
    /// # Case
    ///
    /// if the shape is `[2, 3, 4, 5]` and the strides are `[60, 20, 5, 1]`, they are `contiguous`
    ///
    /// if we permute the shape to `[2, 3, 5, 4]` and the strides to `[60, 20, 1, 5]`, they are generally not `contiguous`, but the first two dimensions still are
    ///
    /// if the reshape only involves the first two dimensions, like `[2, 3, 5, 4]` to `[6, 5, 4]` or `[2, 3, 5, 4]` to `[3, 2, 5, 4]`, it is possible
    ///
    /// if the reshape involves the last two dimensions, like `[2, 3, 5, 4]` to `[2, 3, 20]`, it is not possible
    ///
    /// # Arguments
    ///
    /// * `shape` - the new shape to reshape to
    ///
    /// # Returns
    ///
    /// * `Option<Strides>` - if the reshape is possible, returns the new strides, otherwise returns `None`
    ///
    /// # Examples
    ///
    /// ```
    /// use hpt_common::layout::Layout;
    /// use hpt_common::shape::Shape;
    /// use hpt_common::strides::strides::Strides;
    ///
    /// let shape = Shape::from(vec![2, 3, 4]);
    /// let strides = Strides::from(vec![12, 4, 1]);
    ///
    /// let layout = Layout::new(shape.clone(), strides.clone());
    ///
    /// let new_shape = Shape::from(vec![3, 2, 4]);
    /// let new_strides = layout.is_reshape_possible(&new_shape).unwrap();
    ///
    /// assert_eq!(new_strides, Strides::from(vec![8, 4, 1]));
    /// ```
    pub fn is_reshape_possible(&self, shape: &[i64]) -> Option<Strides> {
        if self.size() != shape.iter().product::<i64>() {
            return None;
        }
        is_reshape_possible(&self.shape, &self.strides, shape)
    }

    /// # Internal Function
    ///
    /// a function used to calculate the broadcast layout for a given target shape
    ///
    /// # Arguments
    ///
    /// * `target_shape` - the target shape
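    ///
    /// # Examples
    ///
    /// a minimal sketch of the intended behavior (marked `ignore`; imports and `Layout::new` are assumed to match the doctest of `is_reshape_possible`):
    ///
    /// ```ignore
    /// use hpt_common::layout::Layout;
    /// use hpt_common::shape::Shape;
    /// use hpt_common::strides::strides::Strides;
    ///
    /// // broadcasting a [3, 1] layout to [2, 3, 4]:
    /// // stretched dimensions get stride 0, matching dimensions keep their stride
    /// let layout = Layout::new(Shape::from(vec![3, 1]), Strides::from(vec![1, 1]));
    /// let broadcasted = layout.to_broadcast_layout(&[2, 3, 4]).unwrap();
    /// assert_eq!(broadcasted.shape().to_vec(), vec![2, 3, 4]);
    /// assert_eq!(broadcasted.strides().to_vec(), vec![0, 1, 0]);
    /// ```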
    pub fn to_broadcast_layout(&self, target_shape: &[i64]) -> Result<Layout, TensorError> {
        let padded_shape = try_pad_shape(&self.shape, target_shape.len());

        // try_pad_shape can also be used on strides
        let padded_strides = try_pad_shape(&self.strides, target_shape.len());
        let mut new_strides = vec![0; target_shape.len()];

        for i in 0..target_shape.len() {
            if padded_shape[i] == target_shape[i] {
                new_strides[i] = padded_strides[i];
            } else {
                new_strides[i] = 0;
            }
        }

        Ok(Layout {
            shape: target_shape.into(),
            strides: new_strides.into(),
        })
    }

    /// # Internal Function
    ///
    /// a function mainly used for expanding the strides of a tensor
    ///
    /// it simply sets the stride of each dimension being expanded to `0`
    ///
    /// # Arguments
    ///
    /// * `expand_shape` - the new shape to expand to
    ///
    /// # Returns
    ///
    /// * `Result<Strides>` - the new strides after expanding
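    ///
    /// # Examples
    ///
    /// a rough sketch (marked `ignore`; imports assumed as in the `is_reshape_possible` doctest):
    ///
    /// ```ignore
    /// use hpt_common::layout::Layout;
    /// use hpt_common::shape::Shape;
    /// use hpt_common::strides::strides::Strides;
    ///
    /// // expanding [3, 1] to [3, 4]: the size-1 dimension gets stride 0
    /// let layout = Layout::new(Shape::from(vec![3, 1]), Strides::from(vec![1, 1]));
    /// let strides = layout.expand_strides(&[3, 4]).unwrap();
    /// assert_eq!(strides.to_vec(), vec![1, 0]);
    ///
    /// // a dimension whose size is not 1 cannot be expanded
    /// assert!(layout.expand_strides(&[4, 4]).is_err());
    /// ```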
    pub fn expand_strides(&self, expand_shape: &[i64]) -> Result<Strides, TensorError> {
        let mut res_strides = vec![0; expand_shape.len()];
        for (((idx, new_dim), old_dim), old_stride) in expand_shape
            .iter()
            .enumerate()
            .rev()
            .zip(self.shape.iter().rev())
            .zip(self.strides.iter().rev())
        {
            if new_dim != old_dim && old_dim == &1 {
                res_strides[idx] = 0;
            } else if new_dim != old_dim && old_dim != &1 {
                return Err(ShapeError::ExpandError {
                    old_dim: *old_dim,
                    location: Location::caller(),
                }
                .into());
            } else {
                res_strides[idx] = *old_stride;
            }
        }
        Ok(res_strides.into())
    }

    /// # Internal Function
    /// a function mainly used for calculating the real (allocated) size of a tensor
    /// pretty useful when the tensor is a view of another tensor
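    ///
    /// # Examples
    ///
    /// a small sketch of the idea (marked `ignore`; imports assumed as in the `is_reshape_possible` doctest):
    ///
    /// ```ignore
    /// use hpt_common::layout::Layout;
    /// use hpt_common::shape::Shape;
    /// use hpt_common::strides::strides::Strides;
    ///
    /// // a contiguous [2, 3, 4] layout spans 2 * 12 = 24 elements
    /// let layout = Layout::new(Shape::from(vec![2, 3, 4]), Strides::from(vec![12, 4, 1]));
    /// assert_eq!(layout.real_size(), 24);
    /// ```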
    pub fn real_size(&self) -> usize {
        assert_eq!(self.shape.len(), self.strides.len());
        let mut max_stride = 0;
        let mut max_idx = 0;
        for (idx, stride) in self.strides.iter().enumerate() {
            if *stride > max_stride {
                max_stride = *stride;
                max_idx = idx;
            }
        }
        (self.shape[max_idx] * max_stride) as usize
    }

    /// # Internal Function
    ///
    /// a function used to calculate the permuted layout
    ///
    /// # Arguments
    ///
    /// * `axes` - the new order of the dimensions
    ///
    /// # Returns
    ///
    /// * `Result<Layout>` - the new layout after permutation
    ///
    /// # Errors
    ///
    /// if the length of `axes` is not equal to the layout's ndim
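    ///
    /// # Examples
    ///
    /// a minimal sketch (marked `ignore`; imports assumed as in the `is_reshape_possible` doctest, and the axes are passed as anything convertible into `Axis`):
    ///
    /// ```ignore
    /// use hpt_common::layout::Layout;
    /// use hpt_common::shape::Shape;
    /// use hpt_common::strides::strides::Strides;
    ///
    /// // moving the last axis of a contiguous [2, 3, 4] layout to the front
    /// let layout = Layout::new(Shape::from(vec![2, 3, 4]), Strides::from(vec![12, 4, 1]));
    /// let permuted = layout.permute(vec![2, 0, 1]).unwrap();
    /// assert_eq!(permuted.shape().to_vec(), vec![4, 2, 3]);
    /// assert_eq!(permuted.strides().to_vec(), vec![1, 12, 4]);
    /// ```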
    #[track_caller]
    pub fn permute<A: Into<Axis>>(&self, axes: A) -> Result<Layout, TensorError> {
        let axes = process_axes(axes, self.shape.len())?;
        ShapeError::check_dim(axes.len(), self.shape.len())?;
        let mut new_shape = self.shape().to_vec();
        let mut new_strides = self.strides().to_vec();
        for i in axes.iter() {
            new_shape[*i] = self.shape()[axes[*i]];
            new_strides[*i] = self.strides()[axes[*i]];
        }
        Ok(Layout {
            shape: new_shape.into(),
            strides: new_strides.into(),
        })
    }

    /// # Internal Function
    ///
    /// a function used to calculate the inverse permuted layout
    ///
    /// # Arguments
    ///
    /// * `axes` - the permutation to invert
    ///
    /// # Returns
    ///
    /// * `Result<Layout>` - the new layout after inverse permutation
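    ///
    /// # Examples
    ///
    /// a minimal sketch (marked `ignore`; same assumptions as the `permute` example):
    ///
    /// ```ignore
    /// use hpt_common::layout::Layout;
    /// use hpt_common::shape::Shape;
    /// use hpt_common::strides::strides::Strides;
    ///
    /// // `permute_inv` with the same axes undoes a `permute`
    /// let layout = Layout::new(Shape::from(vec![2, 3, 4]), Strides::from(vec![12, 4, 1]));
    /// let permuted = layout.permute(vec![2, 0, 1]).unwrap();
    /// let restored = permuted.permute_inv(vec![2, 0, 1]).unwrap();
    /// assert_eq!(restored.shape().to_vec(), vec![2, 3, 4]);
    /// assert_eq!(restored.strides().to_vec(), vec![12, 4, 1]);
    /// ```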
    pub fn permute_inv<A: Into<Axis>>(&self, axes: A) -> Result<Layout, TensorError> {
        let axes = process_axes(axes, self.shape.len())?;
        ShapeError::check_dim(axes.len(), self.shape.len())?;
        let mut new_shape = self.shape().to_vec();
        let mut new_strides = self.strides().to_vec();
        for i in axes.iter() {
            new_shape[axes[*i]] = self.shape()[*i];
            new_strides[axes[*i]] = self.strides()[*i];
        }
        Ok(Layout {
            shape: new_shape.into(),
            strides: new_strides.into(),
        })
    }

    /// # Internal Function
    ///
    /// perform an inplace reshape on the layout
    ///
    /// # Arguments
    ///
    /// * `shape` - the new shape to reshape to
    ///
    /// # Returns
    ///
    /// * `Result<Layout>` - the new layout after the reshape
    ///
    /// # Errors
    ///
    /// if the reshape is not possible with the current strides
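    ///
    /// # Examples
    ///
    /// a minimal sketch (marked `ignore`; imports assumed as in the `is_reshape_possible` doctest):
    ///
    /// ```ignore
    /// use hpt_common::layout::Layout;
    /// use hpt_common::shape::Shape;
    /// use hpt_common::strides::strides::Strides;
    ///
    /// // merging the two leading dimensions of a contiguous [2, 3, 4] layout
    /// let layout = Layout::new(Shape::from(vec![2, 3, 4]), Strides::from(vec![12, 4, 1]));
    /// let reshaped = layout.inplace_reshape(&Shape::from(vec![6, 4])).unwrap();
    /// assert_eq!(reshaped.shape().to_vec(), vec![6, 4]);
    /// assert_eq!(reshaped.strides().to_vec(), vec![4, 1]);
    /// ```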
    #[track_caller]
    pub fn inplace_reshape(&self, shape: &Shape) -> Result<Layout, TensorError> {
        if let Some(new_strides) = self.is_reshape_possible(shape) {
            Ok(Layout {
                shape: shape.clone(),
                strides: new_strides,
            })
        } else {
            Err(ShapeError::InplaceReshapeError {
                message: "Inplace reshape is not possible".to_string(),
                location: Location::caller(),
            }
            .into())
        }
    }

    /// # Internal Function
    ///
    /// broadcast the layout with another layout
    ///
    /// # Arguments
    ///
    /// * `other` - the other layout to broadcast with
    ///
    /// # Returns
    ///
    /// * `Result<Layout>` - the new layout after broadcasting
    ///
    /// # Errors
    ///
    /// if the two shapes cannot be broadcast together
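    ///
    /// # Examples
    ///
    /// a minimal sketch (marked `ignore`; imports assumed as in the `is_reshape_possible` doctest):
    ///
    /// ```ignore
    /// use hpt_common::layout::Layout;
    /// use hpt_common::shape::Shape;
    /// use hpt_common::strides::strides::Strides;
    ///
    /// // broadcasting [2, 1, 4] with [3, 1] yields a contiguous [2, 3, 4] layout
    /// let a = Layout::new(Shape::from(vec![2, 1, 4]), Strides::from(vec![4, 4, 1]));
    /// let b = Layout::new(Shape::from(vec![3, 1]), Strides::from(vec![1, 1]));
    /// let c = a.broadcast(&b).unwrap();
    /// assert_eq!(c.shape().to_vec(), vec![2, 3, 4]);
    /// assert_eq!(c.strides().to_vec(), vec![12, 4, 1]);
    /// ```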
    #[track_caller]
    pub fn broadcast(&self, other: &Layout) -> Result<Layout, TensorError> {
        let shape = predict_broadcast_shape(&self.shape, &other.shape)?;
        let strides = shape_to_strides(&shape);
        Ok(Layout { shape, strides })
    }

    /// # Internal Function
    ///
    /// reduce the layout along the given axes
    ///
    /// this is mainly used for reducing the dimensions of a tensor
    ///
    /// # Arguments
    ///
    /// * `axes` - the axes to be reduced
    ///
    /// * `keep_dims` - whether to keep the reduced dimensions
    ///
    /// # Returns
    ///
    /// * `Result<Layout>` - the new layout after the reduction
    ///
    /// # Errors
    ///
    /// if `axes` contains duplicate axes
    ///
    /// if `axes` contains an axis that is out of range for the layout's ndim
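    ///
    /// # Examples
    ///
    /// a minimal sketch (marked `ignore`; imports assumed as in the `is_reshape_possible` doctest, axes passed as anything convertible into `Axis`):
    ///
    /// ```ignore
    /// use hpt_common::layout::Layout;
    /// use hpt_common::shape::Shape;
    /// use hpt_common::strides::strides::Strides;
    ///
    /// // reducing axis 1 of a [2, 3, 4] layout, with and without keeping the dimension
    /// let layout = Layout::new(Shape::from(vec![2, 3, 4]), Strides::from(vec![12, 4, 1]));
    /// let reduced = layout.reduce(vec![1], false).unwrap();
    /// assert_eq!(reduced.shape().to_vec(), vec![2, 4]);
    /// let kept = layout.reduce(vec![1], true).unwrap();
    /// assert_eq!(kept.shape().to_vec(), vec![2, 1, 4]);
    /// ```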
    pub fn reduce<A: Into<Axis>>(&self, axes: A, keep_dims: bool) -> Result<Layout, TensorError> {
        let a: Axis = axes.into();
        let axis = process_axes(a, self.shape.len())?;
        let new_shape = if keep_dims {
            let mut vec = Vec::with_capacity(self.shape.len());
            for i in 0..self.shape.len() {
                if axis.contains(&i) {
                    vec.push(1);
                } else {
                    vec.push(self.shape[i]);
                }
            }
            vec
        } else {
            let mut vec = Vec::with_capacity(self.shape.len() - axis.len());
            for i in 0..self.shape.len() {
                if !axis.contains(&i) {
                    vec.push(self.shape[i]);
                }
            }
            vec
        };
        if !new_shape.is_empty() {
            let new_strides = shape_to_strides(&new_shape);
            Ok(Layout {
                shape: new_shape.into(),
                strides: new_strides,
            })
        } else {
            Ok(Layout {
                shape: vec![1].into(),
                strides: vec![1].into(),
            })
        }
    }

    /// simply returns the product of the shape
    ///
    /// # Safety
    ///
    /// when the layout is a view of another layout, the number of allocated elements can differ from this product; use `real_size` in that case
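    ///
    /// a small sketch contrasting `size` with `real_size` (marked `ignore`; imports assumed as in the `is_reshape_possible` doctest):
    ///
    /// ```ignore
    /// use hpt_common::layout::Layout;
    /// use hpt_common::shape::Shape;
    /// use hpt_common::strides::strides::Strides;
    ///
    /// // a broadcast view: `size` counts logical elements, `real_size` follows the strides
    /// let layout = Layout::new(Shape::from(vec![2, 3, 4]), Strides::from(vec![0, 4, 1]));
    /// assert_eq!(layout.size(), 24);
    /// assert_eq!(layout.real_size(), 12);
    /// ```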
    #[inline(always)]
    pub fn size(&self) -> i64 {
        self.shape.iter().product::<i64>()
    }

    /// # Internal Function
    ///
    /// check if the layout is contiguous
    ///
    /// # Returns
    ///
    /// * `bool` - whether the layout is contiguous
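    ///
    /// # Examples
    ///
    /// a minimal sketch (marked `ignore`; imports assumed as in the `is_reshape_possible` doctest):
    ///
    /// ```ignore
    /// use hpt_common::layout::Layout;
    /// use hpt_common::shape::Shape;
    /// use hpt_common::strides::strides::Strides;
    ///
    /// let contiguous = Layout::new(Shape::from(vec![2, 3, 4]), Strides::from(vec![12, 4, 1]));
    /// assert!(contiguous.is_contiguous());
    ///
    /// // a permuted view is generally not contiguous
    /// let permuted = contiguous.permute(vec![0, 2, 1]).unwrap();
    /// assert!(!permuted.is_contiguous());
    /// ```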
    pub fn is_contiguous(&self) -> bool {
        let mut expected_stride = 1;
        for (&dim_size, &stride) in self.shape.iter().rev().zip(self.strides.iter().rev()) {
            if dim_size == 0 {
                continue;
            }
            if stride != expected_stride {
                return false;
            }
            expected_stride *= dim_size;
        }
        true
    }

    /// # Internal Function
    ///
    /// coalesce the dimensions of a layout
    ///
    /// # Returns
    ///
    /// * `Vec<Vec<usize>>` - groups of adjacent dimension indices that are contiguous with each other and can be treated as one dimension
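    ///
    /// # Examples
    ///
    /// a minimal sketch (marked `ignore`; imports assumed as in the `is_reshape_possible` doctest):
    ///
    /// ```ignore
    /// use hpt_common::layout::Layout;
    /// use hpt_common::shape::Shape;
    /// use hpt_common::strides::strides::Strides;
    ///
    /// // the permuted [2, 3, 5, 4] layout from the `is_reshape_possible` docs:
    /// // only the first two dimensions are contiguous with each other and can be merged
    /// let layout = Layout::new(Shape::from(vec![2, 3, 5, 4]), Strides::from(vec![60, 20, 1, 5]));
    /// assert_eq!(layout.coalesce_dims(), vec![vec![0, 1], vec![2], vec![3]]);
    /// ```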
    pub fn coalesce_dims(&self) -> Vec<Vec<usize>> {
        let shape = &self.shape;
        let strides = &self.strides;
        let mut groups = vec![vec![shape.len() - 1]];
        let mut current_stride = strides[shape.len() - 1];
        let mut current_size = shape[shape.len() - 1];

        for i in (0..shape.len() - 1).rev() {
            let expected_stride = current_stride * current_size;

            if strides[i] == expected_stride {
                groups.last_mut().unwrap().push(i);
            } else {
                groups.push(vec![i]);
            }

            current_stride = strides[i];
            current_size = shape[i];
        }

        for group in groups.iter_mut() {
            group.reverse();
        }
        groups.reverse();

        groups
    }
}