Function gpu_strided_split 

pub fn gpu_strided_split(
    input: &CudaBuffer<f32>,
    total_along_axis: usize,
    split_offset: usize,
    split_size: usize,
    inner_size: usize,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>>
Extract a sub-tensor along one axis entirely on GPU.

Given an input buffer representing a tensor with total_along_axis elements along the split axis, extracts the half-open slice [split_offset .. split_offset + split_size) along that axis, i.e. split_size elements starting at split_offset.

  • inner_size = product of dimensions after the split axis.
  • n = total number of output elements (outer * split_size * inner_size, where outer is the product of dimensions before the split axis).
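
The parameter relationships above can be illustrated with a CPU reference sketch of the index mapping. This is not the library's GPU kernel. The helper name strided_split_cpu is hypothetical, and the sketch assumes a contiguous row-major layout, which the description implies but does not state outright.

```rust
/// CPU reference for the strided-split index mapping.
/// Hypothetical helper for illustration only; assumes a contiguous,
/// row-major input of shape [outer, total_along_axis, inner_size].
fn strided_split_cpu(
    input: &[f32],
    total_along_axis: usize,
    split_offset: usize,
    split_size: usize,
    inner_size: usize,
    n: usize,
) -> Vec<f32> {
    // The requested slice must lie inside the split axis.
    assert!(split_offset + split_size <= total_along_axis);

    (0..n)
        .map(|i| {
            // Decompose the flat output index into (outer, k, inner).
            let inner = i % inner_size;
            let k = (i / inner_size) % split_size;
            let outer = i / (inner_size * split_size);
            // Map to the flat input index, shifting k by split_offset.
            let src = (outer * total_along_axis + split_offset + k) * inner_size + inner;
            input[src]
        })
        .collect()
}
```

For example, with an input of shape [2, 4, 3] holding the values 0.0 through 23.0, calling strided_split_cpu(&input, 4, 1, 2, 3, 12) selects indices 1 and 2 along the middle axis and returns [3, 4, 5, 6, 7, 8, 15, 16, 17, 18, 19, 20] as f32 values. Here outer = 2, so n = 2 * 2 * 3 = 12.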

Errors