hpt_common/strides/strides_utils.rs

1use crate::strides::strides::Strides;
2
/// # Internal Function
/// Preprocesses strides based on the given shape.
///
/// `stride` is aligned against the *trailing* dimensions of `shape` (right-aligned),
/// and a stride vector with exactly one entry per dimension of `shape` is produced:
///
/// - leading dimensions of `shape` that have no corresponding entry in `stride`
///   get a stride of `0`;
/// - dimensions whose size in `shape` is `1` get a stride of `0`;
/// - every other dimension keeps the corresponding value from `stride`.
///
/// A zero stride means every index along that dimension maps to the same element
/// (presumably used for broadcasting — confirm against callers).
///
/// # Arguments
/// - `shape`: The target shape.
/// - `stride`: The strides to align with the trailing dimensions of `shape`.
///
/// # Returns
/// A `Vec<i64>` with `shape.len()` elements.
///
/// # Panics
/// Panics if `stride.len() > shape.len()`.
pub fn preprocess_strides(shape: &[i64], stride: &[i64]) -> Vec<i64> {
    // Without this check the subtraction below underflows: an opaque overflow panic
    // in debug builds, and a wrapped value leading to out-of-bounds indexing in release.
    assert!(
        stride.len() <= shape.len(),
        "preprocess_strides: stride has {} dimensions but shape only has {}",
        stride.len(),
        shape.len()
    );

    // Number of leading dimensions of `shape` not covered by `stride`.
    let start = shape.len() - stride.len();

    let mut strides = Vec::with_capacity(shape.len());
    // Uncovered leading dimensions get stride 0.
    strides.resize(start, 0);
    // Covered dimensions keep their stride, except size-1 dimensions, which get 0.
    strides.extend(
        shape[start..]
            .iter()
            .zip(stride)
            .map(|(&dim, &st)| if dim == 1 { 0 } else { st }),
    );
    strides
}
19
20/// # Internal Function
21/// Converts a shape to strides.
22///
23/// This function calculates the strides of a tensor based on its shape,
24/// assuming a contiguous memory layout.
25///
26/// # Arguments
27/// - `shape`: A reference to the shape of the tensor.
28///
29/// # Returns
30/// A `Vec<i64>` representing the strides calculated from the shape.
31pub fn shape_to_strides(shape: &[i64]) -> Strides {
32    let mut strides = vec![0; shape.len()];
33    let mut size = 1;
34    for i in (0..shape.len()).rev() {
35        let tmp = shape[i];
36        strides[i] = size;
37        size *= tmp;
38    }
39    strides.into()
40}