hpt_common/
slice.rs

1use crate::error::{base::TensorError, shape::ShapeError};
2
/// Selection applied to a single tensor dimension.
///
/// Each value describes which elements of one dimension a slicing operation keeps.
/// Negative indices are interpreted relative to the end of the dimension.
///
/// It is not meant to be used directly by the user; the library uses it internally.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Slice {
    /// load the single element at the given index
    From(i64),
    /// load all the elements along the corresponding dimension
    Full,
    /// `[start:]`: load from the given start index to the end of the corresponding dimension
    RangeFrom(i64),
    /// `[:end]`: load from the beginning up to (excluding) the given index along the
    /// corresponding dimension
    RangeTo(i64),
    /// `[start:end]`: `(start, end)`, load from the start index up to (excluding) the end
    /// index along the corresponding dimension
    Range((i64, i64)),
    /// `[start::step]`: `(start, step)`, load every `step`-th element beginning at the start
    /// index along the corresponding dimension
    StepByRangeFrom((i64, i64)),
    /// `[::step]`: load every `step`-th element of the whole corresponding dimension
    StepByFullRange(i64),
    /// `[start:end:step]`: `(start, end, step)`, load every `step`-th element from the start
    /// index up to (excluding) the end index along the corresponding dimension
    StepByRangeFromTo((i64, i64, i64)),
    /// `[:end:step]`: load every `step`-th element up to (excluding) the end index along the
    /// corresponding dimension
    ///
    /// NOTE(review): the tuple is presumably `(end, step)`, by analogy with
    /// `StepByRangeFrom`; confirm against the code that constructs this variant.
    StepByRangeTo((i64, i64)),
}
29
30/// # Internal Function
31/// Processes tensor slicing with given strides and shape, adjusting strides and shape
32/// based on the slicing operation and applying an additional scaling factor `alpha`.
33///
34/// This function performs slicing operations on a tensor's shape and strides according to
35/// the provided `index` and scales both the shape and strides by a factor of `alpha`.
36///
37/// # Arguments
38/// - `shape`: A `Vec<i64>` representing the shape of the tensor.
39/// - `strides`: A `Vec<i64>` representing the original strides of the tensor.
40/// - `index`: A slice of `Slice` enums that specify the slicing operations to apply to each dimension.
41/// - `alpha`: A scaling factor of type `i64` that is applied to both the shape and strides.
42///
43/// # Returns
44/// This function returns a `Result` with the following tuple upon success:
45/// - `Vec<i64>`: The new shape of the tensor after applying the slicing and scaling.
46/// - `Vec<i64>`: The new strides after applying the slicing and scaling.
47/// - `i64`: The adjusted pointer offset based on the slicing.
48///
49/// If the `index` length is out of range for the given `shape`, it returns an error.
50///
51/// # Errors
52/// - Returns an error if the `index` length exceeds the number of dimensions in the tensor shape.
53/// - Returns an error if a slicing operation goes out of the bounds of the tensor's shape.
54///
55/// # Examples
56/// ```
57/// use hpt_common::slice_process;
58/// use hpt_types::Slice;
59///
60/// let shape = vec![3, 4, 5];
61/// let strides = vec![20, 5, 1];
62/// let index = vec![Slice::From(1), Slice::Range((0, 3)), Slice::StepByFullRange(2)];
63/// let alpha = 1;
64/// let result = slice_process(shape, strides, &index, alpha).unwrap();
65/// assert_eq!(result, (vec![2, 3, 3], vec![20, 5, 2], 20));
66/// ```
67#[track_caller]
68pub fn slice_process(
69    shape: Vec<i64>,
70    strides: Vec<i64>,
71    index: &[Slice],
72    alpha: i64,
73) -> std::result::Result<(Vec<i64>, Vec<i64>, i64), TensorError> {
74    let mut res_shape: Vec<i64> = shape.clone();
75    let mut res_strides: Vec<i64> = strides.clone();
76    res_shape.iter_mut().for_each(|x| {
77        *x *= alpha;
78    });
79    res_strides.iter_mut().for_each(|x| {
80        *x *= alpha;
81    });
82    let mut res_ptr = 0;
83    if index.len() > res_shape.len() {
84        panic!("index length is greater than the shape length");
85    }
86    for (idx, slice) in index.iter().enumerate() {
87        match slice {
88            Slice::From(mut __index) => {
89                let mut index;
90                if __index >= 0 {
91                    index = __index;
92                } else {
93                    index = __index + shape[idx];
94                }
95                index *= alpha;
96                ShapeError::check_index_out_of_range(index, shape[idx])?;
97                res_shape[idx] = alpha;
98                res_ptr += res_strides[idx] * index;
99            }
100            // tested
101            Slice::RangeFrom(mut __index) => {
102                let index = if __index >= 0 {
103                    __index
104                } else {
105                    __index + shape[idx]
106                };
107                let length = (shape[idx] - index) * alpha;
108                res_shape[idx] = if length > 0 { length } else { 0 };
109                res_ptr += res_strides[idx] * index;
110            }
111            Slice::RangeTo(r) => {
112                let range_to = if *r >= 0 { ..*r } else { ..*r + shape[idx] };
113                let mut end = range_to.end;
114                end *= alpha;
115                if range_to.end > res_shape[idx] {
116                    end = res_shape[idx];
117                }
118                res_shape[idx] = end;
119            }
120            // tested
121            Slice::Range((start, end)) => {
122                let range;
123                if *start >= 0 {
124                    if *end >= 0 {
125                        range = *start..*end;
126                    } else {
127                        range = *start..*end + shape[idx];
128                    }
129                } else if *end >= 0 {
130                    range = *start + shape[idx]..*end;
131                } else {
132                    range = start + shape[idx]..*end + shape[idx];
133                }
134                let mut start = range.start;
135                start *= alpha;
136                let mut end = range.end;
137                end *= alpha;
138                if start >= res_shape[idx] {
139                    start = res_shape[idx];
140                }
141                if end >= res_shape[idx] {
142                    end = res_shape[idx];
143                }
144                if start > end {
145                    res_shape[idx] = 0;
146                } else {
147                    res_shape[idx] = end - start;
148                    res_ptr += strides[idx] * start;
149                }
150            }
151            // tested
152            Slice::StepByRangeFromTo((start, end, step)) => {
153                let mut start = if *start >= 0 {
154                    *start
155                } else {
156                    *start + shape[idx]
157                };
158                let mut end = if *end >= 0 { *end } else { *end + shape[idx] };
159
160                if start >= shape[idx] {
161                    start = shape[idx] - 1;
162                }
163                if end > shape[idx] {
164                    end = shape[idx];
165                }
166
167                let length = if *step > 0 {
168                    (end - start + step - 1) / step
169                } else if *step < 0 {
170                    (end - start + step + 1) / step
171                } else {
172                    0
173                };
174
175                if length > 0 {
176                    res_shape[idx] = length * alpha;
177                    res_ptr += start * res_strides[idx];
178                    res_strides[idx] *= *step;
179                } else {
180                    res_shape[idx] = 0;
181                }
182            }
183            // tested
184            Slice::StepByRangeFrom((start, step)) => {
185                let mut start = if *start >= 0 {
186                    *start
187                } else {
188                    *start + shape[idx]
189                };
190                let end = if *step > 0 { shape[idx] } else { 0 };
191                if start >= shape[idx] {
192                    start = shape[idx] - 1;
193                }
194                let length;
195                if start <= end && *step > 0 {
196                    length = (end - 1 - start + step) / step;
197                } else if start >= end && *step < 0 {
198                    length = (end - start + step) / step;
199                } else {
200                    length = 0;
201                }
202                if length == 1 {
203                    res_shape[idx] = alpha;
204                    res_ptr += res_strides[idx] * start;
205                } else if length >= 0 {
206                    res_shape[idx] = length * alpha;
207                    res_ptr += start * res_strides[idx];
208                    res_strides[idx] *= *step;
209                } else {
210                    res_shape[idx] = 0;
211                }
212            }
213            // tested
214            Slice::StepByFullRange(step) => {
215                let start = if *step > 0 { 0 } else { shape[idx] - 1 };
216                let end = if *step > 0 { shape[idx] - 1 } else { 0 };
217                let length = if (start <= end && *step > 0) || (start >= end && *step < 0) {
218                    (end - start + step) / step
219                } else {
220                    0
221                };
222                if length == 1 {
223                    res_shape[idx] = alpha;
224                    res_ptr += res_strides[idx] * start;
225                } else if length >= 0 {
226                    res_shape[idx] = length * alpha;
227                    res_ptr += start * res_strides[idx];
228                    res_strides[idx] *= *step;
229                } else {
230                    res_shape[idx] = 0;
231                }
232            }
233            _ => {}
234        }
235    }
236
237    let mut new_shape = Vec::new();
238    let mut new_strides = Vec::new();
239    for (i, &s) in res_shape.iter().enumerate() {
240        if s == 0 {
241            continue;
242        }
243        new_shape.push(s);
244        new_strides.push(res_strides[i]);
245    }
246    Ok((new_shape, new_strides, res_ptr))
247}
248
/// Slices a tensor using NumPy-like selection syntax.
///
/// The tokens written between the brackets are handed unchanged to `match_selection!`,
/// and a reference to its result is passed to the tensor's `slice` method. The tensor must
/// be named by a plain identifier, and `match_selection!` must be in scope where this
/// macro is invoked.
///
/// Supported selections, one per dimension and separated by commas:
///
/// `[:::]` and `[:]` and `[::]`: load all the elements along the corresponding dimension
///
/// `[1:]`: load from index 1 to the end along the corresponding dimension
///
/// `[:10]`: load from the beginning to index 9 along the corresponding dimension
///
/// `[1:10]`: load from index 1 to index 9 along the corresponding dimension
///
/// `[1:10:2]`: load from index 1 to index 9 with step 2 along the corresponding dimension
///
/// `[1:10:2, 2:10:3]`: load from index 1 to index 9 with step 2 for the first dimension,
/// and load from index 2 to index 9 with step 3 for the second dimension
///
/// `[::2]`: load all the elements with step 2 along the corresponding dimension
///
/// Example:
/// ```
/// use hpt::prelude::*;
/// let a = Tensor::<f32>::rand([128, 128, 128])?;
/// let res = slice!(a[::2]); // load all the elements with step 2 along the first dimension
/// let res = slice!(a[1:10:2, 2:10:3]); // load from index 1 to index 9 with step 2 for the first dimension, and load from index 2 to index 9 with step 3 for the second dimension
/// ```
#[macro_export]
macro_rules! slice {
    ($tensor:ident [ $($selection:tt)* ]) => {
        $tensor.slice(&match_selection!($($selection)*))
    };
}
279}