hpt_common/slice.rs
1use crate::error::{base::TensorError, shape::ShapeError};
2
/// Slice enum to hold the slice information
///
/// Each variant describes how one dimension of a tensor is sliced, following
/// NumPy/Python slicing conventions: negative indices count from the end of the
/// dimension, and range ends are exclusive.
///
/// It is not being used directly by the user, but is used by the library internally.
#[derive(Debug, Clone)]
pub enum Slice {
    /// `[i]`: load the single element at index `i`
    From(i64),
    /// `[:]`: load all the elements along the corresponding dimension
    Full,
    /// `[start:]`: load from `start` to the end of the corresponding dimension
    RangeFrom(i64),
    /// `[:end]`: load from the beginning of the corresponding dimension up to, but not
    /// including, `end`
    RangeTo(i64),
    /// `[start:end]`: payload `(start, end)`; load from `start` up to, but not
    /// including, `end`
    Range((i64, i64)),
    /// `[start::step]`: payload `(start, step)`; load from `start` towards the end of the
    /// dimension (or towards its beginning when `step` is negative), taking every
    /// `step`-th element
    StepByRangeFrom((i64, i64)),
    /// `[::step]`: load all the elements with step `step` along the corresponding
    /// dimension; a negative step walks the dimension backwards
    StepByFullRange(i64),
    /// `[start:end:step]`: payload `(start, end, step)`; load from `start` up to, but not
    /// including, `end`, taking every `step`-th element
    StepByRangeFromTo((i64, i64, i64)),
    /// `[:end:step]`: presumably payload `(end, step)`, mirroring the other stepped
    /// variants. NOTE(review): confirm against the code that constructs this variant.
    StepByRangeTo((i64, i64)),
}
29
30/// # Internal Function
31/// Processes tensor slicing with given strides and shape, adjusting strides and shape
32/// based on the slicing operation and applying an additional scaling factor `alpha`.
33///
34/// This function performs slicing operations on a tensor's shape and strides according to
35/// the provided `index` and scales both the shape and strides by a factor of `alpha`.
36///
37/// # Arguments
38/// - `shape`: A `Vec<i64>` representing the shape of the tensor.
39/// - `strides`: A `Vec<i64>` representing the original strides of the tensor.
40/// - `index`: A slice of `Slice` enums that specify the slicing operations to apply to each dimension.
41/// - `alpha`: A scaling factor of type `i64` that is applied to both the shape and strides.
42///
43/// # Returns
44/// This function returns a `Result` with the following tuple upon success:
45/// - `Vec<i64>`: The new shape of the tensor after applying the slicing and scaling.
46/// - `Vec<i64>`: The new strides after applying the slicing and scaling.
47/// - `i64`: The adjusted pointer offset based on the slicing.
48///
49/// If the `index` length is out of range for the given `shape`, it returns an error.
50///
51/// # Errors
52/// - Returns an error if the `index` length exceeds the number of dimensions in the tensor shape.
53/// - Returns an error if a slicing operation goes out of the bounds of the tensor's shape.
54///
55/// # Examples
56/// ```
57/// use hpt_common::slice_process;
58/// use hpt_types::Slice;
59///
60/// let shape = vec![3, 4, 5];
61/// let strides = vec![20, 5, 1];
62/// let index = vec![Slice::From(1), Slice::Range((0, 3)), Slice::StepByFullRange(2)];
63/// let alpha = 1;
64/// let result = slice_process(shape, strides, &index, alpha).unwrap();
65/// assert_eq!(result, (vec![2, 3, 3], vec![20, 5, 2], 20));
66/// ```
67#[track_caller]
68pub fn slice_process(
69 shape: Vec<i64>,
70 strides: Vec<i64>,
71 index: &[Slice],
72 alpha: i64,
73) -> std::result::Result<(Vec<i64>, Vec<i64>, i64), TensorError> {
74 let mut res_shape: Vec<i64> = shape.clone();
75 let mut res_strides: Vec<i64> = strides.clone();
76 res_shape.iter_mut().for_each(|x| {
77 *x *= alpha;
78 });
79 res_strides.iter_mut().for_each(|x| {
80 *x *= alpha;
81 });
82 let mut res_ptr = 0;
83 if index.len() > res_shape.len() {
84 panic!("index length is greater than the shape length");
85 }
86 for (idx, slice) in index.iter().enumerate() {
87 match slice {
88 Slice::From(mut __index) => {
89 let mut index;
90 if __index >= 0 {
91 index = __index;
92 } else {
93 index = __index + shape[idx];
94 }
95 index *= alpha;
96 ShapeError::check_index_out_of_range(index, shape[idx])?;
97 res_shape[idx] = alpha;
98 res_ptr += res_strides[idx] * index;
99 }
100 // tested
101 Slice::RangeFrom(mut __index) => {
102 let index = if __index >= 0 {
103 __index
104 } else {
105 __index + shape[idx]
106 };
107 let length = (shape[idx] - index) * alpha;
108 res_shape[idx] = if length > 0 { length } else { 0 };
109 res_ptr += res_strides[idx] * index;
110 }
111 Slice::RangeTo(r) => {
112 let range_to = if *r >= 0 { ..*r } else { ..*r + shape[idx] };
113 let mut end = range_to.end;
114 end *= alpha;
115 if range_to.end > res_shape[idx] {
116 end = res_shape[idx];
117 }
118 res_shape[idx] = end;
119 }
120 // tested
121 Slice::Range((start, end)) => {
122 let range;
123 if *start >= 0 {
124 if *end >= 0 {
125 range = *start..*end;
126 } else {
127 range = *start..*end + shape[idx];
128 }
129 } else if *end >= 0 {
130 range = *start + shape[idx]..*end;
131 } else {
132 range = start + shape[idx]..*end + shape[idx];
133 }
134 let mut start = range.start;
135 start *= alpha;
136 let mut end = range.end;
137 end *= alpha;
138 if start >= res_shape[idx] {
139 start = res_shape[idx];
140 }
141 if end >= res_shape[idx] {
142 end = res_shape[idx];
143 }
144 if start > end {
145 res_shape[idx] = 0;
146 } else {
147 res_shape[idx] = end - start;
148 res_ptr += strides[idx] * start;
149 }
150 }
151 // tested
152 Slice::StepByRangeFromTo((start, end, step)) => {
153 let mut start = if *start >= 0 {
154 *start
155 } else {
156 *start + shape[idx]
157 };
158 let mut end = if *end >= 0 { *end } else { *end + shape[idx] };
159
160 if start >= shape[idx] {
161 start = shape[idx] - 1;
162 }
163 if end > shape[idx] {
164 end = shape[idx];
165 }
166
167 let length = if *step > 0 {
168 (end - start + step - 1) / step
169 } else if *step < 0 {
170 (end - start + step + 1) / step
171 } else {
172 0
173 };
174
175 if length > 0 {
176 res_shape[idx] = length * alpha;
177 res_ptr += start * res_strides[idx];
178 res_strides[idx] *= *step;
179 } else {
180 res_shape[idx] = 0;
181 }
182 }
183 // tested
184 Slice::StepByRangeFrom((start, step)) => {
185 let mut start = if *start >= 0 {
186 *start
187 } else {
188 *start + shape[idx]
189 };
190 let end = if *step > 0 { shape[idx] } else { 0 };
191 if start >= shape[idx] {
192 start = shape[idx] - 1;
193 }
194 let length;
195 if start <= end && *step > 0 {
196 length = (end - 1 - start + step) / step;
197 } else if start >= end && *step < 0 {
198 length = (end - start + step) / step;
199 } else {
200 length = 0;
201 }
202 if length == 1 {
203 res_shape[idx] = alpha;
204 res_ptr += res_strides[idx] * start;
205 } else if length >= 0 {
206 res_shape[idx] = length * alpha;
207 res_ptr += start * res_strides[idx];
208 res_strides[idx] *= *step;
209 } else {
210 res_shape[idx] = 0;
211 }
212 }
213 // tested
214 Slice::StepByFullRange(step) => {
215 let start = if *step > 0 { 0 } else { shape[idx] - 1 };
216 let end = if *step > 0 { shape[idx] - 1 } else { 0 };
217 let length = if (start <= end && *step > 0) || (start >= end && *step < 0) {
218 (end - start + step) / step
219 } else {
220 0
221 };
222 if length == 1 {
223 res_shape[idx] = alpha;
224 res_ptr += res_strides[idx] * start;
225 } else if length >= 0 {
226 res_shape[idx] = length * alpha;
227 res_ptr += start * res_strides[idx];
228 res_strides[idx] *= *step;
229 } else {
230 res_shape[idx] = 0;
231 }
232 }
233 _ => {}
234 }
235 }
236
237 let mut new_shape = Vec::new();
238 let mut new_strides = Vec::new();
239 for (i, &s) in res_shape.iter().enumerate() {
240 if s == 0 {
241 continue;
242 }
243 new_shape.push(s);
244 new_strides.push(res_strides[i]);
245 }
246 Ok((new_shape, new_strides, res_ptr))
247}
248
/// slice operation for tensor
/// slicing uses the same syntax as numpy
///
/// Each comma-separated entry applies to the corresponding dimension; end indices are
/// exclusive and negative indices count from the end of the dimension.
///
/// `[:::]` and `[:]` and `[::]`: load all the elements along the corresponding dimension
///
/// `[1:]`: load from index 1 to the end along the corresponding dimension
///
/// `[:10]`: load from the beginning to index 9 along the corresponding dimension
///
/// `[1:10]`: load from index 1 to index 9 along the corresponding dimension
///
/// `[1:10:2]`: load from index 1 to index 9 with step 2 along the corresponding dimension
///
/// `[1:10:2, 2:10:3]`: load from index 1 to index 9 with step 2 for the first dimension,
/// and load from index 2 to index 9 with step 3 for the second dimension
///
/// `[::2]`: load all the elements with step 2 along the corresponding dimension
///
/// The tensor must be given as a plain identifier. The bracketed tokens are forwarded
/// to `match_selection!`, and its result is passed by reference to the tensor's `slice`
/// method. `match_selection!` is invoked unqualified, so it must be in scope wherever
/// `slice!` is used.
///
/// Example:
/// ```
/// use hpt::prelude::*;
/// let a = Tensor::<f32>::rand([128, 128, 128])?;
/// let res = slice!(a[::2]); // load all the elements with step 2 along the first dimension
/// let res = slice!(a[1:10:2, 2:10:3]); // load from index 1 to index 9 with step 2 for the first dimension, and load from index 2 to index 9 with step 3 for the second dimension
/// # Ok::<(), Box<dyn std::error::Error>>(())
/// ```
#[macro_export]
macro_rules! slice {
    (
        $tensor:ident [$($indexes:tt)*]
    ) => {
        $tensor.slice(&match_selection!($($indexes)*))
    };
}
279}