vectra 0.2.4

A multi-dimensional array library for Rust, similar to NumPy
Documentation
use num_traits::Zero;

use crate::core::MajorOrder;

/// Copies a runtime-sized slice into a fixed-size array of length `D`.
///
/// # Panics
///
/// Panics if `idx.len() != D`, because `copy_from_slice` requires the source
/// and destination to have the same length.
pub fn dyn_dim_to_static<const D: usize, T: Copy + Zero>(idx: &[T]) -> [T; D] {
    // `T::zero()` only seeds the buffer; every slot is overwritten by the copy.
    let mut fixed = [T::zero(); D];
    fixed[..].copy_from_slice(idx);
    fixed
}

/// Computes the per-axis strides (counted in elements) of a contiguous array
/// with the given `shape`.
///
/// * `MajorOrder::RowMajor` — the last axis has stride 1 and each earlier axis
///   has the stride of the next axis multiplied by that axis' extent.
/// * `MajorOrder::ColumnMajor` — the first axis has stride 1 and each later
///   axis has the stride of the previous axis multiplied by that axis' extent.
///
/// A zero-dimensional shape (`D == 0`) yields an empty stride array.
pub fn compute_strides<const D: usize>(shape: [usize; D], major_order: MajorOrder) -> [usize; D] {
    // Every stride starts at 1; the unit-stride axis keeps that value.
    let mut strides = [1; D];

    match major_order {
        MajorOrder::RowMajor => {
            // Walk from the last axis towards the first. Ranging over `1..D`
            // (instead of `0..D - 1`) keeps the loop empty for `D == 0`
            // rather than underflowing.
            for i in (1..D).rev() {
                strides[i - 1] = strides[i] * shape[i];
            }
        }
        MajorOrder::ColumnMajor => {
            for i in 1..D {
                strides[i] = strides[i - 1] * shape[i - 1];
            }
        }
    }

    strides
}

/// Decomposes a flat element offset into one index per axis of `shape`.
///
/// The offset is repeatedly split with `%` and `/` by the axis extents,
/// starting from the fastest-varying axis: the last axis for
/// `MajorOrder::RowMajor`, the first axis for `MajorOrder::ColumnMajor`.
///
/// # Panics
///
/// Panics with a division-by-zero error if any extent in `shape` is `0`.
#[allow(unused)]
pub fn flat_idx_to_indices<const D: usize>(
    shape: [usize; D],
    flat_idx: usize,
    major_order: MajorOrder,
) -> [usize; D] {
    let mut indices = [0; D];
    let mut rest = flat_idx;

    for step in 0..D {
        // Pick the axis to peel off at this step according to the layout.
        let axis = match major_order {
            MajorOrder::RowMajor => D - 1 - step,
            MajorOrder::ColumnMajor => step,
        };

        let extent = shape[axis];
        indices[axis] = rest % extent;
        rest /= extent;
    }

    indices
}

/// Returns the flat element offset addressed by `indices`, computed as the
/// sum over all axes of `strides[axis] * indices[axis]`.
///
/// Returns `0` when `D == 0`.
pub fn indices_to_flat_idx<const D: usize>(strides: [usize; D], indices: [usize; D]) -> usize {
    strides
        .iter()
        .zip(indices.iter())
        .fold(0, |offset, (stride, index)| offset + stride * index)
}

/// Returns the flat element offset of `indices` in an array of the given
/// `shape` and memory layout.
///
/// Convenience wrapper: derives the strides with `compute_strides` and then
/// delegates to `indices_to_flat_idx`.
pub fn shape_indices_to_flat_idx<const D: usize>(
    shape: [usize; D],
    indices: [usize; D],
    major_order: MajorOrder,
) -> usize {
    indices_to_flat_idx(compute_strides(shape, major_order), indices)
}

/// Resolves a set of possibly negative per-axis indices into non-negative
/// ones, axis by axis, using `negative_idx_to_positive` with the matching
/// extent from `shape`.
///
/// # Panics
///
/// Panics if `negative_idx_to_positive` rejects any index for its axis.
pub fn negative_indices_to_positive<const D: usize>(
    indices: [isize; D],
    shape: [usize; D],
) -> [usize; D] {
    // Write straight into the result array; no intermediate heap allocation.
    let mut resolved = [0; D];
    for (axis, slot) in resolved.iter_mut().enumerate() {
        *slot = negative_idx_to_positive(indices[axis], shape[axis]);
    }

    resolved
}

/// Resolves a possibly negative index against an axis of length `guard`.
///
/// Non-negative indices are returned unchanged; a negative index `-k` counts
/// from the end and resolves to `guard - k`. The accepted range is
/// `-guard..guard`.
///
/// The computation is done in unsigned arithmetic, so it stays correct even
/// when `guard` exceeds `isize::MAX`.
///
/// # Panics
///
/// Panics if `idx` is outside `-guard..guard` (which is always the case when
/// `guard == 0`).
pub fn negative_idx_to_positive(idx: isize, guard: usize) -> usize {
    let resolved = if idx < 0 {
        // `-k` maps to `guard - k`; `checked_sub` fails exactly when `k > guard`.
        guard.checked_sub(idx.unsigned_abs())
    } else {
        // For a non-negative value `unsigned_abs` is a lossless conversion.
        let forward = idx.unsigned_abs();
        if forward < guard {
            Some(forward)
        } else {
            None
        }
    };

    match resolved {
        Some(position) => position,
        None => panic!("idx must in -{}..{}", guard, guard),
    }
}

/// Broadcast two shapes of equal rank to a common shape.
///
/// Each axis is resolved independently: equal extents are kept as-is, and an
/// extent of `1` stretches to the other shape's extent.
///
/// # Errors
///
/// Returns an error message naming both shapes if any axis has two different
/// extents and neither of them is `1`.
pub fn broadcast_shapes<const D: usize>(
    shape1: [usize; D],
    shape2: [usize; D],
) -> Result<[usize; D], String> {
    let mut merged = [1; D];

    for (axis, (&lhs, &rhs)) in shape1.iter().zip(shape2.iter()).enumerate() {
        merged[axis] = match (lhs, rhs) {
            (a, b) if a == b => a,
            (1, b) => b,
            (a, 1) => a,
            _ => {
                return Err(format!(
                    "Cannot broadcast shapes {:?} and {:?}",
                    shape1, shape2
                ));
            }
        };
    }

    Ok(merged)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Strides, flat offsets and the inverse decomposition must agree with
    /// each other for both memory layouts.
    #[test]
    fn test_shape_strides_idx() {
        let shape = [2, 3];

        let row_strides = compute_strides(shape, MajorOrder::RowMajor);
        assert_eq!(row_strides, [3, 1]);
        let col_strides = compute_strides(shape, MajorOrder::ColumnMajor);
        assert_eq!(col_strides, [1, 2]);

        let row_idx = indices_to_flat_idx(row_strides, [1, 1]);
        let row_idx1 = shape_indices_to_flat_idx(shape, [1, 1], MajorOrder::RowMajor);
        assert_eq!(row_idx, 4);
        assert_eq!(row_idx, row_idx1);

        let col_idx = indices_to_flat_idx(col_strides, [1, 1]);
        let col_idx1 = shape_indices_to_flat_idx(shape, [1, 1], MajorOrder::ColumnMajor);
        assert_eq!(col_idx, 3);
        assert_eq!(col_idx, col_idx1);

        // Converting each flat offset back must recover the original indices.
        let indices_1 = flat_idx_to_indices(shape, row_idx, MajorOrder::RowMajor);
        let indices_2 = flat_idx_to_indices(shape, col_idx, MajorOrder::ColumnMajor);

        assert_eq!(indices_1, indices_2);
        assert_eq!(indices_1, [1, 1]);
        assert_eq!(indices_2, [1, 1]);
    }

    /// Covers stretching on either operand, on both operands at once,
    /// identical shapes, and an incompatible pair.
    #[test]
    fn test_shape_broadcast() {
        // Second operand is stretched along the first axis.
        let shape1 = [2, 3];
        let shape2 = [1, 3];
        let result = broadcast_shapes(shape1, shape2).unwrap();
        assert_eq!(result, [2, 3]);

        // Same pair with the operands swapped: first operand is stretched.
        let shape1 = [1, 3];
        let shape2 = [2, 3];
        let result = broadcast_shapes(shape1, shape2).unwrap();
        assert_eq!(result, [2, 3]);

        // Each operand is stretched along a different axis.
        let shape1 = [2, 1];
        let shape2 = [1, 3];
        let result = broadcast_shapes(shape1, shape2).unwrap();
        assert_eq!(result, [2, 3]);

        // Identical shapes broadcast to themselves.
        let shape1 = [2, 3];
        let shape2 = [2, 3];
        let result = broadcast_shapes(shape1, shape2).unwrap();
        assert_eq!(result, [2, 3]);

        // Differing extents with neither equal to 1 cannot be broadcast.
        let shape1 = [2, 3];
        let shape2 = [2, 4];
        let result = broadcast_shapes(shape1, shape2);
        assert!(result.is_err());
    }
}