hpt_common/
slice.rs

1use crate::error::base::TensorError;
2
/// # Internal Function
/// Computes the shape, strides and element offset of a sliced view of a tensor.
///
/// Each entry of `index` is a `(start, end, step)` triple that is applied to the dimension at
/// the same position; dimensions without an entry are kept whole. Before slicing, every extent
/// and stride is multiplied by `alpha`, and the length of each sliced dimension is multiplied
/// by `alpha` as well.
///
/// Two sentinel encodings are recognised:
/// - an `end` equal to `i64::MAX` means "up to the end of the dimension";
/// - the triple `(0, 0, i64::MAX)` is an ellipsis. Its first occurrence is replaced by as many
///   full-range slices `(0, i64::MAX, 1)` as are needed for the index to cover every
///   dimension, which may be none at all.
///
/// A negative `start` or `end` has the dimension size added to it once. A `start` at or past
/// the end of the dimension is clamped to the last element and an `end` past the end is clamped
/// to the dimension size.
///
/// Any dimension whose resulting extent is zero (an empty slice, a `step` of zero, or an extent
/// that was already zero) is removed from the returned shape and strides.
///
/// # Arguments
/// - `shape`: A `Vec<i64>` holding the extent of every dimension.
/// - `strides`: A `Vec<i64>` holding the stride of every dimension.
/// - `index`: One `(start, end, step)` triple per sliced dimension, optionally containing the
///   ellipsis sentinel described above.
/// - `alpha`: An `i64` factor applied to the extents and strides.
///
/// # Returns
/// `Ok((shape, strides, offset))` where
/// - `shape`: the extents of the view, without the zero-extent dimensions;
/// - `strides`: the matching strides (original stride × `alpha` × `step`);
/// - `offset`: the sum of `start * stride * alpha` over the dimensions that were sliced
///   to a non-empty range.
///
/// # Errors
/// This function currently has no path that returns `Err`; the `Result` wrapper is kept for
/// the callers' sake. Invalid input is reported by panicking, see below.
///
/// # Panics
/// - If `index`, not counting the ellipsis, has more entries than `shape` has dimensions.
/// - If `strides` is shorter than `shape` and one of the missing entries is needed.
///
/// # Examples
/// ```
/// use hpt_common::slice_process;
///
/// let shape = vec![3, 4, 5];
/// let strides = vec![20, 5, 1];
/// // `1..`, `0..3` and "every second element" on the three dimensions.
/// let index = vec![(1, i64::MAX, 1), (0, 3, 1), (0, i64::MAX, 2)];
/// let alpha = 1;
/// let result = slice_process(shape, strides, &index, alpha).unwrap();
/// assert_eq!(result, (vec![2, 3, 3], vec![20, 5, 2], 20));
/// ```
#[track_caller]
pub fn slice_process(
    shape: Vec<i64>,
    strides: Vec<i64>,
    index: &[(i64, i64, i64)],
    alpha: i64,
) -> std::result::Result<(Vec<i64>, Vec<i64>, i64), TensorError> {
    /// `end` value meaning "slice up to the end of the dimension".
    const OPEN_END: i64 = i64::MAX;
    /// Triple that stands for "all the dimensions not indexed explicitly".
    const ELLIPSIS: (i64, i64, i64) = (0, 0, i64::MAX);

    let ellipsis_pos = index.iter().position(|&idx| idx == ELLIPSIS);
    // The ellipsis does not consume a dimension by itself, so it must not count towards the
    // rank check: `[a, ..., b]` is a valid index for a 2-D tensor.
    let explicit_dims = index.len() - usize::from(ellipsis_pos.is_some());
    if explicit_dims > shape.len() {
        panic!("index length is greater than the shape length");
    }

    let mut res_shape: Vec<i64> = shape.iter().map(|&dim| dim * alpha).collect();
    let mut res_strides = strides;
    res_strides.iter_mut().for_each(|stride| *stride *= alpha);
    let mut res_ptr: i64 = 0;

    // Replace the ellipsis by the full-range slices it stands for.
    let mut new_indices = Vec::with_capacity(shape.len());
    match ellipsis_pos {
        Some(pos) => {
            let missing_dims = shape.len() - explicit_dims;
            new_indices.extend_from_slice(&index[..pos]);
            new_indices.extend(std::iter::repeat((0, OPEN_END, 1)).take(missing_dims));
            new_indices.extend_from_slice(&index[pos + 1..]);
        }
        None => new_indices.extend_from_slice(index),
    }

    for (idx, (start, end, step)) in new_indices.into_iter().enumerate() {
        let dim = shape[idx];
        let end = if end == OPEN_END { dim } else { end };

        // Negative positions count from the end of the dimension.
        let start = if start >= 0 { start } else { start + dim };
        let end = if end >= 0 { end } else { end + dim };

        // NOTE(review): with a positive step, a `start` past the end still selects the last
        // element instead of producing an empty slice. Kept as-is; confirm this is intended.
        let start = start.min(dim - 1);
        let end = end.min(dim);

        // Number of selected elements, i.e. ceil((end - start) / step) for either sign.
        let length = match step {
            s if s > 0 => (end - start + s - 1) / s,
            s if s < 0 => (end - start + s + 1) / s,
            _ => 0,
        };

        if length > 0 {
            res_shape[idx] = length * alpha;
            res_ptr += start * res_strides[idx];
            res_strides[idx] *= step;
        } else {
            res_shape[idx] = 0;
        }
    }

    // Zero-extent dimensions are dropped from the result.
    let mut new_shape = Vec::with_capacity(res_shape.len());
    let mut new_strides = Vec::with_capacity(res_shape.len());
    for (i, &dim) in res_shape.iter().enumerate() {
        if dim != 0 {
            new_shape.push(dim);
            new_strides.push(res_strides[i]);
        }
    }
    Ok((new_shape, new_strides, res_ptr))
}