acme_tensor/types/
tensors.rs

/*
    Appellation: tensors <mod>
    Contrib: FL03 <jo3mccain@icloud.com>
*/
5use crate::shape::Rank;
6use crate::tensor::TensorBase;
7use strum::{Display, EnumCount, EnumDiscriminants, EnumIs, EnumIter, EnumString, VariantNames};
8
/// A value that is either a bare scalar of type `T` or a full [`TensorBase<T>`].
///
/// Deriving `EnumDiscriminants` generates a companion field-less enum named
/// `TensorType` (see `name(TensorType)` below) with matching variants
/// (`Scalar`, `Tensor`). `TensorType` derives the traits listed in
/// `strum_discriminants(derive(...))`, and additionally serde's
/// `Serialize`/`Deserialize` when the `serde` feature is enabled.
///
/// With the `serde` feature, this enum also derives `Serialize`/`Deserialize`
/// with its variant names renamed to lowercase (`serde(rename_all = "lowercase")`).
///
/// NOTE(review): `#[strum(serialize_all = "lowercase")]` is attached to this
/// enum, but this enum derives no strum trait that renders variant names
/// (only `EnumCount`, `EnumDiscriminants`, `EnumIs`). It is not evident from
/// here that the setting reaches the generated `TensorType`'s
/// `Display`/`EnumString`. Confirm against strum's `EnumDiscriminants` docs
/// if lowercase discriminant names are intended.
#[derive(Clone, Debug, EnumCount, EnumDiscriminants, EnumIs, Eq, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(serde::Deserialize, serde::Serialize),
    serde(rename_all = "lowercase"),
    strum_discriminants(derive(serde::Deserialize, serde::Serialize))
)]
#[repr(C)]
#[strum(serialize_all = "lowercase")]
#[strum_discriminants(
    derive(
        Display,
        EnumCount,
        EnumIs,
        EnumIter,
        EnumString,
        Hash,
        Ord,
        PartialOrd,
        VariantNames
    ),
    name(TensorType)
)]
pub enum Tensors<T> {
    /// A single bare value of type `T`.
    Scalar(T),
    /// A tensor whose elements are of type `T`.
    Tensor(TensorBase<T>),
}
36
37impl<T> Tensors<T> {
38    pub fn scalar(scalar: T) -> Self {
39        Self::Scalar(scalar)
40    }
41
42    pub fn tensor(tensor: TensorBase<T>) -> Self {
43        Self::Tensor(tensor)
44    }
45
46    pub fn rank(&self) -> Rank {
47        match self {
48            Self::Tensor(tensor) => tensor.rank(),
49            _ => Rank::scalar(),
50        }
51    }
52}
53
54impl<T> From<TensorBase<T>> for Tensors<T>
55where
56    T: Clone,
57{
58    fn from(tensor: TensorBase<T>) -> Self {
59        if tensor.rank().is_scalar() {
60            Self::Scalar(unsafe { tensor.into_scalar() })
61        } else {
62            Self::Tensor(tensor)
63        }
64    }
65}
66
#[cfg(test)]
mod tests {
    use super::*;

    /// A 2x3 tensor wrapped via `Tensors::tensor` stays a `Tensor` and reports rank 2.
    #[test]
    fn test_tensor_type() {
        let shape = (2, 3);
        let tensor = TensorBase::<f64>::ones(shape);
        let item = Tensors::tensor(tensor);

        assert!(item.is_tensor());
        assert_eq!(item.rank(), Rank::from(2));
    }

    /// A bare value wrapped via `Tensors::scalar` is a `Scalar` and reports the scalar rank.
    #[test]
    fn test_scalar_type() {
        let item = Tensors::scalar(1.0_f64);

        assert!(item.is_scalar());
        assert_eq!(item.rank(), Rank::scalar());
    }
}
79}