hpt_common/strides/strides_utils.rs
1use crate::strides::strides::Strides;
2
/// # Internal Function
/// Preprocesses strides based on the given shape.
///
/// Produces one stride per dimension of `shape`. The entries of `stride` are
/// aligned with the trailing dimensions of `shape`: leading dimensions that
/// `stride` does not cover receive a stride of 0, and so does every dimension
/// whose size is 1. All other dimensions keep the stride supplied for them.
///
/// # Arguments
/// - `shape`: The shape to produce strides for.
/// - `stride`: Strides for the last `stride.len()` dimensions of `shape`.
///
/// # Returns
/// A `Vec<i64>` with exactly `shape.len()` elements.
///
/// # Panics
/// Panics if `stride` has more elements than `shape`.
pub fn preprocess_strides(shape: &[i64], stride: &[i64]) -> Vec<i64> {
    // A plain `usize` subtraction would overflow here when `stride` is longer
    // than `shape`; check it explicitly so the failure is reported clearly and
    // identically in debug and release builds.
    let start = shape
        .len()
        .checked_sub(stride.len())
        .expect("preprocess_strides: `stride` must not have more dimensions than `shape`");

    // Everything starts at 0, which is already the final value for uncovered
    // leading dimensions and for dimensions of size 1.
    let mut strides = vec![0; shape.len()];
    for ((out, &dim), &step) in strides[start..]
        .iter_mut()
        .zip(&shape[start..])
        .zip(stride)
    {
        if dim != 1 {
            *out = step;
        }
    }
    strides
}
19
20/// # Internal Function
21/// Converts a shape to strides.
22///
23/// This function calculates the strides of a tensor based on its shape,
24/// assuming a contiguous memory layout.
25///
26/// # Arguments
27/// - `shape`: A reference to the shape of the tensor.
28///
29/// # Returns
30/// A `Vec<i64>` representing the strides calculated from the shape.
31pub fn shape_to_strides(shape: &[i64]) -> Strides {
32 let mut strides = vec![0; shape.len()];
33 let mut size = 1;
34 for i in (0..shape.len()).rev() {
35 let tmp = shape[i];
36 strides[i] = size;
37 size *= tmp;
38 }
39 strides.into()
40}