zyx 0.14.0

Zyx machine learning library
Documentation
//! A few traits that describe shapes, axes, padding, etc., plus helpers that compute derived shapes.

use core::fmt::Debug;
use core::ops::{Add, Range, RangeInclusive};

/// Size of one dimension of a shape (the item type yielded by `IntoShape::into_shape`).
pub(crate) type Dimension = usize;
/// Alias for `usize`; judging by its name, presumably an axis index (not referenced in this part of the file).
pub(crate) type Axis = usize;

/// Values that can be turned into a shape, i.e. an ordered list of dimensions.
///
/// Implemented below for a single dimension, 2- and 3-tuples, fixed-size
/// arrays, slices and vectors of dimensions.
pub trait IntoShape: Clone + Debug {
    /// Number of dimensions this value describes (the rank of the shape).
    fn rank(&self) -> usize;
    /// Consume the value and yield its dimensions in order.
    fn into_shape(self) -> impl Iterator<Item = Dimension>;
}

impl IntoShape for Dimension {
    /// A lone dimension is a rank-1 shape holding just that dimension.
    fn into_shape(self) -> impl Iterator<Item = Dimension> {
        core::iter::once(self)
    }

    fn rank(&self) -> usize {
        1
    }
}

impl IntoShape for (Dimension, Dimension) {
    /// A pair is a rank-2 shape; the dimensions keep tuple order.
    fn into_shape(self) -> impl Iterator<Item = Dimension> {
        let (first, second) = self;
        IntoIterator::into_iter([first, second])
    }

    fn rank(&self) -> usize {
        2
    }
}

impl IntoShape for (Dimension, Dimension, Dimension) {
    /// A triple is a rank-3 shape; the dimensions keep tuple order.
    fn into_shape(self) -> impl Iterator<Item = Dimension> {
        let (d0, d1, d2) = self;
        IntoIterator::into_iter([d0, d1, d2])
    }

    fn rank(&self) -> usize {
        3
    }
}

impl<const N: usize> IntoShape for [Dimension; N] {
    /// An array of `N` dimensions is a rank-`N` shape, yielded by value.
    fn into_shape(self) -> impl Iterator<Item = Dimension> {
        IntoIterator::into_iter(self)
    }

    fn rank(&self) -> usize {
        N
    }
}

impl IntoShape for &[Dimension] {
    /// A borrowed slice of dimensions; each dimension is copied out.
    fn into_shape(self) -> impl Iterator<Item = Dimension> {
        self.iter().copied()
    }

    fn rank(&self) -> usize {
        <[Dimension]>::len(self)
    }
}

impl IntoShape for Vec<Dimension> {
    /// An owned vector of dimensions, consumed in order.
    fn into_shape(self) -> impl Iterator<Item = Dimension> {
        IntoIterator::into_iter(self)
    }

    fn rank(&self) -> usize {
        Vec::len(self)
    }
}

impl IntoShape for &Vec<Dimension> {
    /// A borrowed vector of dimensions; each dimension is copied out.
    fn into_shape(self) -> impl Iterator<Item = Dimension> {
        self.iter().copied()
    }

    fn rank(&self) -> usize {
        self.as_slice().len()
    }
}

/// Resolves a possibly negative axis index against a shape with `rank` dimensions.
///
/// `axis` is shifted by `rank` and reduced modulo `rank`. As a result, `-1` names the
/// last axis, `-rank` names the first, and non-negative axes below `rank` are returned
/// unchanged. Non-negative axes `>= rank` wrap around modulo `rank`.
///
/// # Panics
///
/// Panics if `rank` is zero, if `rank` does not fit into `T`, or if `axis < -rank`.
pub(crate) fn to_axis<T>(axis: T, rank: usize) -> usize
where
    usize: TryInto<T>,
    T: TryInto<usize>,
    T: Add<Output = T>,
    <usize as TryInto<T>>::Error: Debug,
    <T as TryInto<usize>>::Error: Debug,
{
    // A rank-0 shape has no axes; without this check the final `% rank`
    // would panic with an opaque "remainder with a divisor of zero" message.
    assert!(rank > 0, "cannot resolve an axis of a rank-0 shape");
    let rank_t = <usize as TryInto<T>>::try_into(rank)
        .expect("rank does not fit into the axis type");
    // Shifting by `rank` maps `-rank..0` onto `0..rank`; anything that is
    // still negative was below `-rank` and cannot name an axis.
    let shifted = <T as TryInto<usize>>::try_into(axis + rank_t)
        .expect("axis out of range: negative axes must be >= -rank");
    shifted % rank
}

/// A user-supplied selection of axes that can be resolved into concrete axis indices.
///
/// Implemented below for `()`, a single `isize`, `Vec<isize>`, `[isize; N]`,
/// `Range<isize>` and `RangeInclusive<isize>`.
pub trait IntoAxes: Clone {
    /// Resolve the selection against a shape with `rank` dimensions and
    /// iterate over the resulting axis indices.
    fn into_axes(self, rank: usize) -> impl Iterator<Item = usize>;

    /// Number of axes the selection specifies, computed without knowing the rank.
    fn len(&self) -> usize;
}

impl IntoAxes for () {
    /// The unit value selects no axes, whatever the rank.
    fn into_axes(self, _: usize) -> impl Iterator<Item = usize> {
        core::iter::empty::<usize>()
    }

    fn len(&self) -> usize {
        0
    }
}

impl IntoAxes for isize {
    /// A single (possibly negative) axis, resolved with `to_axis`.
    fn into_axes(self, rank: usize) -> impl Iterator<Item = usize> {
        let resolved = to_axis(self, rank);
        core::iter::once(resolved)
    }

    fn len(&self) -> usize {
        1
    }
}

impl IntoAxes for Vec<isize> {
    /// Resolves each listed axis with `to_axis`, in order.
    ///
    /// An empty vector selects every axis `0..rank`.
    fn into_axes(self, rank: usize) -> impl Iterator<Item = usize> {
        let fallback = if self.is_empty() { 0..rank } else { 0..0 };
        self.into_iter()
            .map(move |axis| to_axis(axis, rank))
            .chain(fallback)
    }

    fn len(&self) -> usize {
        // Call the inherent method explicitly so it is obviously not recursive.
        Vec::len(self)
    }
}

impl<const N: usize> IntoAxes for [isize; N] {
    /// Resolves each listed axis with `to_axis`, in order.
    ///
    /// A zero-length array selects every axis `0..rank`.
    fn into_axes(self, rank: usize) -> impl Iterator<Item = usize> {
        let everything = if N == 0 { 0..rank } else { 0..0 };
        IntoIterator::into_iter(self)
            .map(move |axis| to_axis(axis, rank))
            .chain(everything)
    }

    fn len(&self) -> usize {
        N
    }
}

impl IntoAxes for Range<isize> {
    /// Resolves the half-open range `start..end` into axis indices for a shape
    /// with `rank` dimensions.
    ///
    /// Each bound is resolved like a slice index:
    /// - A negative bound counts back from `rank`, so `-1` is the last axis.
    /// - A non-negative bound is clamped to `rank`.
    ///
    /// Because the exclusive end may equal `rank`, `0..rank` selects every axis.
    /// If the resolved range is empty, all axes `0..rank` are yielded instead.
    /// This follows the "empty selection means every axis" rule used by the
    /// `Vec` and array impls.
    ///
    /// # Panics
    ///
    /// Panics inside `to_axis` when a negative bound is used with `rank == 0`,
    /// or when a bound is smaller than `-rank`.
    fn into_axes(self, rank: usize) -> impl Iterator<Item = usize> {
        let resolve = |bound: isize| -> usize {
            if bound < 0 {
                to_axis(bound, rank)
            } else {
                // `bound` is non-negative, so `unsigned_abs` is a lossless conversion.
                // It is clamped rather than wrapped: `to_axis` would map `rank` to 0.
                bound.unsigned_abs().min(rank)
            }
        };
        // Emptiness is judged on the *resolved* range. Judging it on the raw
        // isize bounds would make e.g. `0..-1` yield its axes followed by every
        // axis again.
        let axes = resolve(self.start)..resolve(self.end);
        let all = if axes.is_empty() { 0..rank } else { 0..0 };
        axes.chain(all)
    }

    /// Number of axes spanned by the raw bounds; reversed ranges count as 0.
    ///
    /// This is computed without a rank, so for bounds of different signs
    /// (e.g. `0..-1`) it does not match the number of axes `into_axes` yields.
    fn len(&self) -> usize {
        usize::try_from(self.end.saturating_sub(self.start)).unwrap_or(0)
    }
}

impl IntoAxes for RangeInclusive<isize> {
    /// Resolves the inclusive range `start..=end` into axis indices for a shape
    /// with `rank` dimensions.
    ///
    /// Both bounds are resolved with `to_axis`, so negative values count back
    /// from the last axis and `0..=-1` selects every axis. If the resolved range
    /// is empty, all axes `0..rank` are yielded instead.
    ///
    /// # Panics
    ///
    /// Panics inside `to_axis` if `rank == 0` or a bound is smaller than `-rank`.
    fn into_axes(self, rank: usize) -> impl Iterator<Item = usize> {
        // Keep the resolved end inclusive. Turning it into an exclusive end
        // would silently drop the last requested axis.
        let axes = to_axis(*self.start(), rank)..=to_axis(*self.end(), rank);
        // Judge emptiness after resolution. The raw bounds of e.g. `0..=-1`
        // look empty even though they select axes.
        let all = if axes.is_empty() { 0..rank } else { 0..0 };
        axes.chain(all)
    }

    /// Number of axes spanned by the raw bounds (`end - start + 1`).
    ///
    /// Reversed ranges count as 0 instead of wrapping to a huge value. This is
    /// computed without a rank, so for bounds of different signs it does not
    /// match the number of axes `into_axes` yields.
    fn len(&self) -> usize {
        usize::try_from(self.end().saturating_sub(*self.start()))
            .map_or(0, |span| span.saturating_add(1))
    }
}

/// Conversion into a list of `(isize, isize)` padding pairs.
///
/// NOTE(review): the meaning of each pair (presumably the amounts added before
/// and after one dimension) is defined by the callers, not here — confirm.
pub trait IntoPadding {
    /// Consume `self` and return the padding pairs as a `Vec`.
    fn into_padding(self) -> Vec<(isize, isize)>;
}

impl<I> IntoPadding for I
where
    I: IntoIterator<Item = (isize, isize)>,
{
    /// Collect the iterator's pairs, preserving their order.
    fn into_padding(self) -> Vec<(isize, isize)> {
        self.into_iter().collect::<Vec<_>>()
    }
}

/// Reorders `shape` so that output position `i` holds `shape[axes[i]]`.
///
/// # Panics
///
/// Panics if `shape` and `axes` differ in length, or if any entry of `axes`
/// is not a valid index into `shape`.
pub(crate) fn permute(shape: &[usize], axes: &[usize]) -> Vec<usize> {
    assert_eq!(shape.len(), axes.len());
    let reordered = axes.iter().map(|&axis| shape[axis]);
    reordered.collect()
}

/// Returns `shape` with every dimension whose index appears in `axes` removed.
///
/// If nothing remains, the result is `[1]` rather than an empty shape.
pub(crate) fn reduce(shape: &[usize], axes: &[usize]) -> Vec<usize> {
    let kept: Vec<usize> = shape
        .iter()
        .enumerate()
        .filter(|(index, _)| !axes.contains(index))
        .map(|(_, &dim)| dim)
        .collect();
    // Reducing over every axis leaves a single-element shape, not a rank-0 one.
    if kept.is_empty() {
        vec![1]
    } else {
        kept
    }
}