use crate::error::base::TensorError;
/// Computes the shape, strides and element offset of a strided view produced by slicing.
///
/// `shape` and `strides` describe the source layout. Every dimension and every stride is
/// first multiplied by `alpha`; the returned shape, strides and offset are expressed in
/// those scaled units.
///
/// Each entry of `index` is a `(start, end, step)` triple applied to one dimension, in order:
/// * negative `start`/`end` count from the end of the dimension;
/// * an `end` of `0x7FFF_FFFF_FFFF_FFFF` means "up to the end of the dimension";
/// * out-of-range bounds are clamped the way Python slicing clamps them (`[0, dim]` for a
///   positive step, `[-1, dim - 1]` for a negative step), so the resulting view never
///   addresses elements outside the source dimension;
/// * the triple `(0, 0, 0x7FFF_FFFF_FFFF_FFFF)` is an ellipsis marker: it expands to as many
///   full slices as are needed so that the indices after it line up with the trailing
///   dimensions. Only the first such marker is expanded.
///
/// Dimensions not covered by `index` keep their (scaled) size and stride. A dimension whose
/// slice selects no elements (including any slice with `step == 0`) is removed from the
/// result, as is any dimension whose scaled size is `0`.
///
/// Returns `(new_shape, new_strides, offset)`, where `offset` is the scaled element offset of
/// the first selected element relative to the source origin.
///
/// NOTE(review): assumes `strides.len() == shape.len()` — confirm at call sites.
///
/// # Panics
///
/// Panics with "index length is greater than the shape length" if `index` addresses more
/// dimensions than `shape` has (the ellipsis marker itself does not count as a dimension).
///
/// # Errors
///
/// The current implementation never returns `Err`.
#[track_caller]
pub fn slice_process(
    shape: Vec<i64>,
    strides: Vec<i64>,
    index: &[(i64, i64, i64)],
    alpha: i64,
) -> std::result::Result<(Vec<i64>, Vec<i64>, i64), TensorError> {
    // `end` value meaning "up to the end of the dimension".
    const TO_END: i64 = 0x7FFF_FFFF_FFFF_FFFF;
    // Index triple that marks an ellipsis.
    const ELLIPSIS: (i64, i64, i64) = (0, 0, TO_END);
    // Full slice of one dimension, used to fill the gap an ellipsis stands for.
    const FULL: (i64, i64, i64) = (0, TO_END, 1);

    let mut res_shape: Vec<i64> = shape.iter().map(|&dim| dim * alpha).collect();
    // `strides` is owned and not needed afterwards, so scale it in place without cloning.
    let mut res_strides: Vec<i64> = strides.into_iter().map(|s| s * alpha).collect();

    let ellipsis_pos = index.iter().position(|&idx| idx == ELLIPSIS);
    // The ellipsis marker itself does not consume a dimension, so it must not count
    // towards the limit (otherwise an ellipsis expanding to zero dims panics spuriously).
    let explicit = index.len() - usize::from(ellipsis_pos.is_some());
    if explicit > shape.len() {
        panic!("index length is greater than the shape length");
    }

    let new_indices: Vec<(i64, i64, i64)> = match ellipsis_pos {
        Some(pos) => index[..pos]
            .iter()
            .copied()
            .chain(std::iter::repeat(FULL).take(shape.len() - explicit))
            .chain(index[pos + 1..].iter().copied())
            .collect(),
        None => index.to_vec(),
    };

    let mut res_ptr = 0;
    for (idx, (start, end, step)) in new_indices.into_iter().enumerate() {
        let dim = shape[idx];
        let end = if end == TO_END { dim } else { end };
        // Negative bounds count from the end of the dimension.
        let start = if start < 0 { start + dim } else { start };
        let end = if end < 0 { end + dim } else { end };
        // Clamp like Python slicing: [0, dim] for forward steps, [-1, dim - 1] for backward
        // steps (-1 meaning "before the first element"). This guarantees every selected
        // element lies inside the dimension, so `res_ptr` and the view stay in bounds.
        let (start, end) = if step > 0 {
            (start.max(0).min(dim), end.max(0).min(dim))
        } else {
            (start.max(-1).min(dim - 1), end.max(-1).min(dim - 1))
        };
        // Number of selected elements (ceil(|end - start| / |step|) when non-empty),
        // written so that no intermediate `+ step` can overflow for huge steps.
        let length = if step > 0 && end > start {
            (end - start - 1) / step + 1
        } else if step < 0 && start > end {
            1 - (start - end - 1) / step
        } else {
            0
        };
        if length > 0 {
            res_shape[idx] = length * alpha;
            res_ptr += start * res_strides[idx];
            res_strides[idx] *= step;
        } else {
            res_shape[idx] = 0;
        }
    }

    // Drop zero-size dimensions together with their strides.
    let (new_shape, new_strides): (Vec<i64>, Vec<i64>) = res_shape
        .into_iter()
        .zip(res_strides)
        .filter(|&(dim, _)| dim != 0)
        .unzip();
    Ok((new_shape, new_strides, res_ptr))
}