yscv-tensor 0.1.8

SIMD-accelerated tensor library with f32/f16/bf16 support and 80+ operations
Documentation
/// Returns the product of all dimensions in `shape`.
///
/// An empty shape yields `Some(1)`. Dimensions are multiplied from left to
/// right with overflow checking; the first multiplication that overflows
/// `usize` makes the whole call return `None`, even if a later dimension
/// is zero.
pub(crate) fn shape_element_count(shape: &[usize]) -> Option<usize> {
    let mut total = 1usize;
    for &dim in shape {
        total = total.checked_mul(dim)?;
    }
    Some(total)
}

/// Computes one stride per axis of `shape`, with the last axis having
/// stride 1 and each earlier axis having the product of all later dimensions.
///
/// Returns `None` if any running product overflows `usize`. The product is
/// also extended by the first (outermost) dimension, so a shape whose total
/// element count overflows is rejected as well.
pub(crate) fn compute_strides(shape: &[usize]) -> Option<Vec<usize>> {
    // Build the strides innermost-first, then flip them into axis order.
    let mut innermost_first = Vec::with_capacity(shape.len());
    let mut running = 1usize;
    for &dim in shape.iter().rev() {
        innermost_first.push(running);
        running = running.checked_mul(dim)?;
    }
    innermost_first.reverse();
    Some(innermost_first)
}

/// Combines two shapes into the shape both can be broadcast to.
///
/// The shorter shape is treated as if it were padded on the left with
/// dimensions of size 1. For each aligned axis the dimensions must be equal,
/// or one of them must be 1 (in which case the other is used). Returns
/// `None` as soon as an axis violates that rule.
pub(crate) fn broadcast_shape(left: &[usize], right: &[usize]) -> Option<Vec<usize>> {
    let rank = left.len().max(right.len());

    // Dimension of `shape` at output axis `axis`, with implicit leading 1s.
    let dim_at = |shape: &[usize], axis: usize| -> usize {
        let pad = rank - shape.len();
        if axis >= pad {
            shape[axis - pad]
        } else {
            1
        }
    };

    (0..rank)
        .map(|axis| match (dim_at(left, axis), dim_at(right, axis)) {
            (l, r) if l == r => Some(l),
            (1, r) => Some(r),
            (l, 1) => Some(l),
            _ => None,
        })
        .collect()
}

/// Maps coordinates in a broadcast output onto a linear offset into an
/// operand described by `shape` and `strides`.
///
/// `shape` is aligned against the trailing entries of `out_coords`. Axes
/// where `shape` has size 1 always contribute coordinate 0; every other axis
/// contributes `out_coords[..] * strides[axis]`. An empty `shape` yields 0.
///
/// Expects `out_coords.len() >= shape.len()` and `strides.len() >=
/// shape.len()`; the subtraction and indexing below are not guarded.
pub(crate) fn broadcast_offset(shape: &[usize], strides: &[usize], out_coords: &[usize]) -> usize {
    // Count of leading output axes that this operand does not have.
    let lead = out_coords.len() - shape.len();

    let mut offset = 0usize;
    for (axis, &dim) in shape.iter().enumerate() {
        let coord = match dim {
            // A size-1 axis is stretched, so it always reads index 0.
            1 => 0,
            _ => out_coords[lead + axis],
        };
        offset += coord * strides[axis];
    }
    offset
}

/// Advances `coords` by one position, treating it as a counter whose last
/// axis changes fastest and whose per-axis limits are given by `shape`.
///
/// Starting from the last axis, each coordinate is incremented; if it is
/// still below `shape[axis]` the function stops, otherwise it is reset to 0
/// and the carry moves to the previous axis. When every axis carries, all
/// coordinates end up at 0. Empty `coords` is left untouched.
pub(crate) fn increment_coords(coords: &mut [usize], shape: &[usize]) {
    for (axis, coord) in coords.iter_mut().enumerate().rev() {
        *coord += 1;
        if *coord < shape[axis] {
            return;
        }
        // This axis is exhausted: reset it and carry into the next one out.
        *coord = 0;
    }
}