use-ml-tensor 0.0.1

Tensor shape and metadata primitives for RustUse machine-learning workflows.
Documentation
#![forbid(unsafe_code)]
#![doc = include_str!("../README.md")]

use core::{fmt, str::FromStr};
use std::error::Error;

/// Maximum number of dimensions accepted by [`TensorShape::new`] and [`TensorRank::new`].
pub const MAX_TENSOR_RANK: usize = 64;

pub mod prelude {
    pub use crate::{
        MAX_TENSOR_RANK, TensorAxis, TensorDType, TensorDeviceKind, TensorDim, TensorLayout,
        TensorMemoryFormat, TensorRank, TensorShape, TensorShapeError,
    };
}

/// An ordered list of tensor dimension sizes.
///
/// The number of dimensions (the rank) is guaranteed to be at most [`MAX_TENSOR_RANK`];
/// individual sizes are not restricted and may be `0`.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct TensorShape {
    dims: Vec<usize>,
}

impl TensorShape {
    /// Builds a shape from the given dimension sizes.
    ///
    /// Only the number of dimensions is validated; any size, including `0`, is accepted.
    ///
    /// # Errors
    ///
    /// Returns [`TensorShapeError::RankTooLarge`] when more than [`MAX_TENSOR_RANK`]
    /// dimensions are supplied.
    pub fn new(dims: impl Into<Vec<usize>>) -> Result<Self, TensorShapeError> {
        let dims = dims.into();
        if dims.len() > MAX_TENSOR_RANK {
            return Err(TensorShapeError::RankTooLarge {
                rank: dims.len(),
                max: MAX_TENSOR_RANK,
            });
        }

        Ok(Self { dims })
    }

    /// Returns the rank-0 shape, which has no dimensions.
    #[must_use]
    pub const fn scalar() -> Self {
        Self { dims: Vec::new() }
    }

    /// Returns the number of dimensions.
    #[must_use]
    pub fn rank(&self) -> usize {
        self.dims.len()
    }

    /// Returns the dimension sizes in the order they were supplied to [`TensorShape::new`].
    #[must_use]
    pub fn dims(&self) -> &[usize] {
        &self.dims
    }

    /// Returns the total number of elements (the product of all dimension sizes).
    ///
    /// The scalar shape yields `Some(1)` (the empty product), and any shape containing a
    /// `0` dimension yields `Some(0)`. Returns `None` only when the product of a shape
    /// with no zero dimensions overflows `usize`.
    #[must_use]
    pub fn num_elements(&self) -> Option<usize> {
        // A zero-length dimension empties the tensor however large the others are.
        // Check it first: multiplying left to right could otherwise overflow on the
        // earlier dimensions and report `None` for a tensor that holds 0 elements.
        if self.dims.contains(&0) {
            return Some(0);
        }

        self.dims
            .iter()
            .copied()
            .try_fold(1_usize, usize::checked_mul)
    }

    /// Returns `true` for a rank-0 shape.
    #[must_use]
    pub fn is_scalar(&self) -> bool {
        self.rank() == 0
    }

    /// Returns `true` for a rank-1 shape.
    #[must_use]
    pub fn is_vector(&self) -> bool {
        self.rank() == 1
    }

    /// Returns `true` for a rank-2 shape.
    #[must_use]
    pub fn is_matrix(&self) -> bool {
        self.rank() == 2
    }
}

/// A newtype over `usize` holding one tensor dimension value; no validation is performed.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct TensorDim(usize);

impl TensorDim {
    /// Wraps `value`. Every `usize`, including `0`, is accepted.
    #[must_use]
    pub const fn new(value: usize) -> Self {
        Self(value)
    }

    /// Returns the wrapped value.
    #[must_use]
    pub const fn get(self) -> usize {
        self.0
    }
}

impl fmt::Display for TensorDim {
    /// Formats exactly like the inner `usize`, honouring width/fill/alignment flags.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Fully-qualified call: method syntax (`self.0.fmt(..)`) only resolves when the
        // `fmt::Display` trait itself is imported, and it becomes ambiguous with
        // `fmt::Debug` if both are. This form needs neither import.
        fmt::Display::fmt(&self.0, formatter)
    }
}

/// A validated tensor rank (number of dimensions), at most [`MAX_TENSOR_RANK`].
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct TensorRank(usize);

impl TensorRank {
    /// Validates `value` as a rank.
    ///
    /// This is a `const fn`, so ranks can also be validated in constant contexts.
    ///
    /// # Errors
    ///
    /// Returns [`TensorShapeError::RankTooLarge`] when `value` exceeds [`MAX_TENSOR_RANK`].
    pub const fn new(value: usize) -> Result<Self, TensorShapeError> {
        if value > MAX_TENSOR_RANK {
            Err(TensorShapeError::RankTooLarge {
                rank: value,
                max: MAX_TENSOR_RANK,
            })
        } else {
            Ok(Self(value))
        }
    }

    /// Returns the rank as a plain `usize`.
    #[must_use]
    pub const fn get(self) -> usize {
        self.0
    }
}

/// A zero-based axis index. Values built with [`TensorAxis::new`] are checked against a rank;
/// values built with [`TensorAxis::unchecked`] are not.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct TensorAxis(usize);

impl TensorAxis {
    /// Validates `value` as an axis of a tensor with the given `rank`.
    ///
    /// Valid axes are `0..rank`, so a rank-0 (scalar) tensor has no valid axis.
    /// This is a `const fn`, usable in constant contexts.
    ///
    /// # Errors
    ///
    /// Returns [`TensorShapeError::AxisOutOfBounds`] when `value >= rank`.
    pub const fn new(value: usize, rank: TensorRank) -> Result<Self, TensorShapeError> {
        let rank = rank.get();
        if value < rank {
            Ok(Self(value))
        } else {
            Err(TensorShapeError::AxisOutOfBounds { axis: value, rank })
        }
    }

    /// Wraps `value` without checking it against any rank.
    ///
    /// No `unsafe` is involved, but the caller is responsible for the index being in range
    /// wherever it is later used.
    #[must_use]
    pub const fn unchecked(value: usize) -> Self {
        Self(value)
    }

    /// Returns the zero-based axis index.
    #[must_use]
    pub const fn index(self) -> usize {
        self.0
    }
}

/// Declares a fieldless metadata enum whose variants map one-to-one onto string labels.
///
/// For each `Variant => "label"` pair, the generated type gets:
/// - `ALL`: every variant, in declaration order;
/// - `as_str`: the variant's label;
/// - `Display`: writes the label;
/// - `FromStr`: passes the input through `normalized_label` and matches the result against
///   the labels, failing with `TensorShapeError::UnknownLabel` when nothing matches.
///
/// Optional outer attributes (including `///` docs) may precede the enum name and each
/// variant. Variant order matters: it defines the derived `Ord`/`PartialOrd`.
macro_rules! tensor_enum {
    (
        $(#[$enum_meta:meta])*
        $name:ident {
            $( $(#[$variant_meta:meta])* $variant:ident => $label:literal ),+ $(,)?
        }
    ) => {
        $(#[$enum_meta])*
        #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub enum $name {
            $( $(#[$variant_meta])* $variant ),+
        }

        impl $name {
            /// Every variant, in declaration order.
            pub const ALL: &'static [Self] = &[$(Self::$variant),+];

            /// Returns the canonical label for this variant.
            #[must_use]
            pub const fn as_str(self) -> &'static str {
                match self {
                    $(Self::$variant => $label),+
                }
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = TensorShapeError;

            fn from_str(value: &str) -> Result<Self, Self::Err> {
                match normalized_label(value)?.as_str() {
                    $($label => Ok(Self::$variant),)+
                    _ => Err(TensorShapeError::UnknownLabel),
                }
            }
        }
    };
}

// Tensor element data-type labels. `tensor_enum!` derives `Ord`/`PartialOrd`, so the
// variant order below is also the comparison order; reordering changes how values sort.
// Parsing goes through `normalized_label`, which trims and ASCII-lowercases the input
// before it is matched against these labels.
tensor_enum!(TensorDType {
    Bool => "bool",
    Int8 => "int8",
    Int16 => "int16",
    Int32 => "int32",
    Int64 => "int64",
    Uint8 => "uint8",
    Uint16 => "uint16",
    Uint32 => "uint32",
    Uint64 => "uint64",
    Float16 => "float16",
    BFloat16 => "bfloat16",
    Float32 => "float32",
    Float64 => "float64",
    Complex64 => "complex64",
    Complex128 => "complex128",
    String => "string",
    Unknown => "unknown",
});

// Tensor layout labels. Variant order is the derived `Ord` order.
// Multi-word labels are hyphenated; `normalized_label` maps `_` and spaces to `-`, so
// "block_sparse" and "block sparse" also parse as `BlockSparse`.
tensor_enum!(TensorLayout {
    Dense => "dense",
    Sparse => "sparse",
    Ragged => "ragged",
    Quantized => "quantized",
    BlockSparse => "block-sparse",
    Unknown => "unknown",
});

// Device-kind labels. Variant order is the derived `Ord` order; parsing is
// case-insensitive for ASCII because `normalized_label` lowercases the input.
tensor_enum!(TensorDeviceKind {
    Cpu => "cpu",
    Gpu => "gpu",
    Tpu => "tpu",
    Npu => "npu",
    Metal => "metal",
    Vulkan => "vulkan",
    Wasm => "wasm",
    Unknown => "unknown",
});

// Memory-format labels. Variant order is the derived `Ord` order.
// Multi-word labels are hyphenated; "channels_last" and "channels last" also parse because
// `normalized_label` maps `_` and spaces to `-`.
tensor_enum!(TensorMemoryFormat {
    Contiguous => "contiguous",
    ChannelsLast => "channels-last",
    ChannelsFirst => "channels-first",
    Sparse => "sparse",
    Unknown => "unknown",
});

/// Errors from validating shapes, ranks and axes, or from parsing metadata labels.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TensorShapeError {
    /// The label was empty, or contained only whitespace.
    EmptyLabel,
    /// The label did not match any known variant.
    UnknownLabel,
    /// A rank greater than the allowed maximum was requested.
    RankTooLarge {
        /// The requested rank.
        rank: usize,
        /// The largest permitted rank.
        max: usize,
    },
    /// An axis index was not smaller than the rank it was checked against.
    AxisOutOfBounds {
        /// The offending axis index.
        axis: usize,
        /// The rank the axis was checked against.
        rank: usize,
    },
}

impl fmt::Display for TensorShapeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            TensorShapeError::EmptyLabel => f.write_str("tensor metadata label cannot be empty"),
            TensorShapeError::UnknownLabel => f.write_str("unknown tensor metadata label"),
            TensorShapeError::RankTooLarge { rank, max } => write!(
                f,
                "tensor rank {rank} exceeds maximum rank {max}"
            ),
            TensorShapeError::AxisOutOfBounds { axis, rank } => write!(
                f,
                "tensor axis {axis} is outside rank {rank}"
            ),
        }
    }
}

// No underlying cause is carried, so the default `source()` (returning `None`) is kept.
impl Error for TensorShapeError {}

/// Canonicalises a user-supplied metadata label before it is matched against known labels.
///
/// Leading and trailing whitespace is trimmed. ASCII letters are lowercased. Every `_` and
/// every remaining (interior) whitespace character becomes `-`. For example,
/// `" Block Sparse "`, `"block_sparse"` and `"block\tsparse"` all become `"block-sparse"`.
///
/// # Errors
///
/// Returns `TensorShapeError::EmptyLabel` if nothing remains after trimming.
fn normalized_label(value: &str) -> Result<String, TensorShapeError> {
    let trimmed = value.trim();
    if trimmed.is_empty() {
        return Err(TensorShapeError::EmptyLabel);
    }

    // `trim` strips every kind of whitespace at the edges, so interior whitespace is
    // treated the same way (tabs and newlines too, not just ' '). A single pass over the
    // characters also avoids the intermediate String that lowercasing then replacing allocated.
    Ok(trimmed
        .chars()
        .map(|ch| {
            if ch == '_' || ch.is_whitespace() {
                '-'
            } else {
                ch.to_ascii_lowercase()
            }
        })
        .collect())
}

#[cfg(test)]
mod tests {
    // Unit tests covering both the success paths and every `TensorShapeError` path.
    use super::{
        MAX_TENSOR_RANK, TensorAxis, TensorDType, TensorDeviceKind, TensorDim, TensorLayout,
        TensorMemoryFormat, TensorRank, TensorShape, TensorShapeError,
    };

    #[test]
    fn models_tensor_shapes() -> Result<(), TensorShapeError> {
        let scalar = TensorShape::scalar();
        let vector = TensorShape::new([3])?;
        let matrix = TensorShape::new([2, 3])?;
        let tensor = TensorShape::new([2, 3, 4])?;

        assert!(scalar.is_scalar());
        assert!(vector.is_vector());
        assert!(matrix.is_matrix());
        assert_eq!(tensor.rank(), 3);
        assert_eq!(tensor.dims(), &[2, 3, 4]);
        assert_eq!(tensor.num_elements(), Some(24));
        Ok(())
    }

    #[test]
    fn protects_element_count_overflow() -> Result<(), TensorShapeError> {
        let shape = TensorShape::new([usize::MAX, 2])?;

        assert_eq!(shape.num_elements(), None);
        Ok(())
    }

    #[test]
    fn counts_elements_for_scalar_and_zero_sized_shapes() -> Result<(), TensorShapeError> {
        // The scalar shape is the empty product; any zero dimension empties the tensor.
        assert_eq!(TensorShape::scalar().num_elements(), Some(1));
        assert_eq!(TensorShape::new([4, 0, 7])?.num_elements(), Some(0));
        Ok(())
    }

    #[test]
    fn rejects_rank_above_maximum() {
        let too_many = vec![1_usize; MAX_TENSOR_RANK + 1];
        let expected = Err(TensorShapeError::RankTooLarge {
            rank: MAX_TENSOR_RANK + 1,
            max: MAX_TENSOR_RANK,
        });

        assert_eq!(TensorShape::new(too_many), expected);
        assert!(TensorShape::new(vec![1_usize; MAX_TENSOR_RANK]).is_ok());
        assert_eq!(
            TensorRank::new(MAX_TENSOR_RANK + 1),
            Err(TensorShapeError::RankTooLarge {
                rank: MAX_TENSOR_RANK + 1,
                max: MAX_TENSOR_RANK,
            })
        );
    }

    #[test]
    fn validates_rank_and_axis() -> Result<(), TensorShapeError> {
        let rank = TensorRank::new(3)?;
        let axis = TensorAxis::new(2, rank)?;

        assert_eq!(rank.get(), 3);
        assert_eq!(axis.index(), 2);
        assert_eq!(
            TensorAxis::new(3, rank),
            Err(TensorShapeError::AxisOutOfBounds { axis: 3, rank: 3 })
        );
        Ok(())
    }

    #[test]
    fn displays_and_parses_tensor_enums() -> Result<(), TensorShapeError> {
        assert_eq!("float32".parse::<TensorDType>()?, TensorDType::Float32);
        assert_eq!(
            "block sparse".parse::<TensorLayout>()?,
            TensorLayout::BlockSparse
        );
        assert_eq!("gpu".parse::<TensorDeviceKind>()?, TensorDeviceKind::Gpu);
        assert_eq!(
            "channels_last".parse::<TensorMemoryFormat>()?,
            TensorMemoryFormat::ChannelsLast
        );
        Ok(())
    }

    #[test]
    fn display_output_parses_back() -> Result<(), TensorShapeError> {
        assert_eq!(TensorDType::BFloat16.to_string(), "bfloat16");
        assert_eq!(
            TensorLayout::BlockSparse.to_string().parse::<TensorLayout>()?,
            TensorLayout::BlockSparse
        );
        assert_eq!(TensorMemoryFormat::ChannelsFirst.as_str(), "channels-first");
        assert_eq!(" CPU ".parse::<TensorDeviceKind>()?, TensorDeviceKind::Cpu);
        assert_eq!(TensorDim::new(12).to_string(), "12");
        Ok(())
    }

    #[test]
    fn rejects_empty_and_unknown_labels() {
        assert_eq!("".parse::<TensorDType>(), Err(TensorShapeError::EmptyLabel));
        assert_eq!("   ".parse::<TensorLayout>(), Err(TensorShapeError::EmptyLabel));
        assert_eq!(
            "float33".parse::<TensorDType>(),
            Err(TensorShapeError::UnknownLabel)
        );
        assert_eq!(
            "quantum".parse::<TensorDeviceKind>(),
            Err(TensorShapeError::UnknownLabel)
        );
    }

    #[test]
    fn formats_error_messages() {
        assert_eq!(
            TensorShapeError::AxisOutOfBounds { axis: 3, rank: 3 }.to_string(),
            "tensor axis 3 is outside rank 3"
        );
        assert_eq!(
            TensorShapeError::EmptyLabel.to_string(),
            "tensor metadata label cannot be empty"
        );
    }
}