slice_process

Function slice_process 

pub fn slice_process(
    shape: Vec<i64>,
    strides: Vec<i64>,
    index: &[(i64, i64, i64)],
    alpha: i64,
) -> Result<(Vec<i64>, Vec<i64>, i64), TensorError>

§Internal Function

Computes the shape, strides, and pointer offset of the tensor view produced by a slicing operation. The entries of index are applied dimension by dimension to the given shape and strides, and the resulting shape and strides are then scaled by the factor alpha.
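The per-dimension arithmetic can be sketched as follows. This is a minimal illustration, not the library's implementation: slice_dim is a hypothetical helper, the (start, end, step) reading of each index entry is an assumption based on the tuple type in the signature, only positive steps are handled, and alpha is left out (equivalent to alpha = 1).

```rust
/// Hypothetical helper: slice one dimension with a positive step.
/// Returns (new_len, new_stride, offset_contribution).
/// Assumes each index entry means (start, end, step) with an exclusive end.
fn slice_dim(len: i64, stride: i64, start: i64, end: i64, step: i64) -> (i64, i64, i64) {
    assert!(step > 0 && 0 <= start && start <= end && end <= len);
    // Number of elements visited: ceil((end - start) / step).
    let new_len = (end - start + step - 1) / step;
    // One step along the sliced axis now skips `step` original elements.
    let new_stride = stride * step;
    // The first selected element sits `start` original strides from the base.
    let offset = start * stride;
    (new_len, new_stride, offset)
}

fn main() {
    // Length 5, stride 1, every second element: indices 0, 2, 4.
    assert_eq!(slice_dim(5, 1, 0, 5, 2), (3, 2, 0));
    // Length 3, stride 20, elements from index 1 onward.
    assert_eq!(slice_dim(3, 20, 1, 3, 1), (2, 20, 20));
}
```

Summing the offset contributions over all dimensions gives the pointer offset of the view, which is why only one i64 offset is returned alongside the per-dimension shape and strides.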

§Arguments

  • shape: A Vec<i64> representing the shape of the tensor.
  • strides: A Vec<i64> representing the original strides of the tensor.
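  • index: A slice of (i64, i64, i64) tuples, one per sliced dimension, specifying the slicing operation to apply to that dimension. The signature takes plain tuples, presumably (start, end, step), rather than Slice enum values.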
  • alpha: A scaling factor of type i64 that is applied to both the shape and strides.

§Returns

This function returns a Result with the following tuple upon success:

  • Vec<i64>: The new shape of the tensor after applying the slicing and scaling.
  • Vec<i64>: The new strides after applying the slicing and scaling.
  • i64: The adjusted pointer offset based on the slicing.

On failure it returns a TensorError; the conditions are listed under Errors.

§Errors

  • Returns an error if the index length exceeds the number of dimensions in the tensor shape.
  • Returns an error if a slicing operation goes out of the bounds of the tensor’s shape.
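The two conditions above could be checked roughly like this. The sketch is self-contained and makes several assumptions: SliceCheckError and check_index are hypothetical names (the real function reports failures through TensorError), index entries are read as (start, end, step), and negative indices are rejected, which may differ from what the library does.

```rust
/// Hypothetical error type for this sketch; the real function returns `TensorError`.
#[derive(Debug, PartialEq)]
enum SliceCheckError {
    TooManyIndices { given: usize, ndim: usize },
    OutOfBounds { dim: usize, start: i64, end: i64, len: i64 },
}

/// Checks the two documented failure conditions for (start, end, step) triples.
fn check_index(shape: &[i64], index: &[(i64, i64, i64)]) -> Result<(), SliceCheckError> {
    // More index entries than dimensions cannot be applied.
    if index.len() > shape.len() {
        return Err(SliceCheckError::TooManyIndices { given: index.len(), ndim: shape.len() });
    }
    // Each range must lie inside its dimension (assumption: no negative indices).
    for (dim, (&(start, end, _step), &len)) in index.iter().zip(shape).enumerate() {
        if start < 0 || end > len || start > end {
            return Err(SliceCheckError::OutOfBounds { dim, start, end, len });
        }
    }
    Ok(())
}

fn main() {
    let shape = [3, 4, 5];
    assert_eq!(check_index(&shape, &[(1, 3, 1), (0, 3, 1), (0, 5, 2)]), Ok(()));
    assert!(matches!(
        check_index(&shape, &[(0, 1, 1); 4]),
        Err(SliceCheckError::TooManyIndices { .. })
    ));
    assert!(matches!(
        check_index(&shape, &[(0, 9, 1)]),
        Err(SliceCheckError::OutOfBounds { dim: 0, .. })
    ));
}
```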

§Examples

use hpt_common::slice_process;
use hpt_types::Slice;

let shape = vec![3, 4, 5];
let strides = vec![20, 5, 1];
let index = vec![Slice::From(1), Slice::Range((0, 3)), Slice::StepByFullRange(2)];
let alpha = 1;
let result = slice_process(shape, strides, &index, alpha).unwrap();
assert_eq!(result, (vec![2, 3, 3], vec![20, 5, 2], 20));
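To see what the returned strides and offset mean, the following sketch applies the values from the expected result above to a flat buffer. It does not call the library; flat_index is a hypothetical helper, and the buffer is assumed to be a contiguous 3×4×5 tensor whose element values equal their flat positions.

```rust
/// Hypothetical helper: flat position of a multi-dimensional index in a strided view.
fn flat_index(idx: &[i64], strides: &[i64], offset: i64) -> usize {
    // Dot product of the index with the strides, shifted by the view's offset.
    let dot: i64 = idx.iter().zip(strides).map(|(i, s)| i * s).sum();
    (offset + dot) as usize
}

fn main() {
    // Contiguous 3x4x5 tensor; each element stores its own flat position.
    let data: Vec<i64> = (0..60).collect();
    // Strides and offset taken from the expected result in the example above.
    let (strides, offset) = (vec![20, 5, 2], 20);
    // View element [0, 0, 0] is original element [1, 0, 0].
    assert_eq!(data[flat_index(&[0, 0, 0], &strides, offset)], 20);
    // View element [1, 2, 2] is original element [2, 2, 4]: 2*20 + 2*5 + 4 = 54.
    assert_eq!(data[flat_index(&[1, 2, 2], &strides, offset)], 54);
}
```

The view shares the original buffer: slicing changes only the shape, strides, and starting offset used to address it.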