hpt_common/slice.rs
1use crate::error::base::TensorError;
2
3/// # Internal Function
4/// Processes tensor slicing with given strides and shape, adjusting strides and shape
5/// based on the slicing operation and applying an additional scaling factor `alpha`.
6///
7/// This function performs slicing operations on a tensor's shape and strides according to
8/// the provided `index` and scales both the shape and strides by a factor of `alpha`.
9///
10/// # Arguments
11/// - `shape`: A `Vec<i64>` representing the shape of the tensor.
12/// - `strides`: A `Vec<i64>` representing the original strides of the tensor.
13/// - `index`: A slice of `Slice` enums that specify the slicing operations to apply to each dimension.
14/// - `alpha`: A scaling factor of type `i64` that is applied to both the shape and strides.
15///
16/// # Returns
17/// This function returns a `Result` with the following tuple upon success:
18/// - `Vec<i64>`: The new shape of the tensor after applying the slicing and scaling.
19/// - `Vec<i64>`: The new strides after applying the slicing and scaling.
20/// - `i64`: The adjusted pointer offset based on the slicing.
21///
22/// If the `index` length is out of range for the given `shape`, it returns an error.
23///
24/// # Errors
25/// - Returns an error if the `index` length exceeds the number of dimensions in the tensor shape.
26/// - Returns an error if a slicing operation goes out of the bounds of the tensor's shape.
27///
28/// # Examples
29/// ```
30/// use hpt_common::slice_process;
31/// use hpt_types::Slice;
32///
33/// let shape = vec![3, 4, 5];
34/// let strides = vec![20, 5, 1];
35/// let index = vec![Slice::From(1), Slice::Range((0, 3)), Slice::StepByFullRange(2)];
36/// let alpha = 1;
37/// let result = slice_process(shape, strides, &index, alpha).unwrap();
38/// assert_eq!(result, (vec![2, 3, 3], vec![20, 5, 2], 20));
39/// ```
40#[track_caller]
41pub fn slice_process(
42 shape: Vec<i64>,
43 strides: Vec<i64>,
44 index: &[(i64, i64, i64)],
45 alpha: i64,
46) -> std::result::Result<(Vec<i64>, Vec<i64>, i64), TensorError> {
47 let mut res_shape: Vec<i64> = shape.clone();
48 let mut res_strides: Vec<i64> = strides.clone();
49 res_shape.iter_mut().for_each(|x| {
50 *x *= alpha;
51 });
52 res_strides.iter_mut().for_each(|x| {
53 *x *= alpha;
54 });
55 let mut res_ptr = 0;
56 if index.len() > res_shape.len() {
57 panic!("index length is greater than the shape length");
58 }
59 let mut new_indices = Vec::with_capacity(shape.len());
60 let ellipsis_pos = index
61 .iter()
62 .position(|&idx| idx == (0, 0, 0x7FFFFFFFFFFFFFFF));
63 if let Some(pos) = ellipsis_pos {
64 let missing_dims = shape.len() - (index.len() - 1);
65 new_indices.extend_from_slice(&index[0..pos]);
66 for _ in 0..missing_dims {
67 new_indices.push((0, 0x7FFFFFFFFFFFFFFF, 1));
68 }
69 new_indices.extend_from_slice(&index[pos + 1..]);
70 } else {
71 new_indices = index.to_vec();
72 }
73
74 for (idx, (start, mut end, step)) in new_indices.into_iter().enumerate() {
75 if end == 0x7FFFFFFFFFFFFFFF {
76 end = shape[idx];
77 }
78 let mut start = if start >= 0 {
79 start
80 } else {
81 start + shape[idx]
82 };
83 let mut end = if end >= 0 { end } else { end + shape[idx] };
84
85 if start >= shape[idx] {
86 start = shape[idx] - 1;
87 }
88 if end > shape[idx] {
89 end = shape[idx];
90 }
91
92 let length = if step > 0 {
93 (end - start + step - 1) / step
94 } else if step < 0 {
95 (end - start + step + 1) / step
96 } else {
97 0
98 };
99
100 if length > 0 {
101 res_shape[idx] = length * alpha;
102 res_ptr += start * res_strides[idx];
103 res_strides[idx] *= step;
104 } else {
105 res_shape[idx] = 0;
106 }
107 }
108 let mut new_shape = Vec::new();
109 let mut new_strides = Vec::new();
110 for (i, &s) in res_shape.iter().enumerate() {
111 if s == 0 {
112 continue;
113 }
114 new_shape.push(s);
115 new_strides.push(res_strides[i]);
116 }
117 Ok((new_shape, new_strides, res_ptr))
118}