acme_tensor/types/
tensors.rs1use crate::shape::Rank;
6use crate::tensor::TensorBase;
7use strum::{Display, EnumCount, EnumDiscriminants, EnumIs, EnumIter, EnumString, VariantNames};
8
9#[derive(Clone, Debug, EnumCount, EnumDiscriminants, EnumIs, Eq, PartialEq)]
10#[cfg_attr(
11 feature = "serde",
12 derive(serde::Deserialize, serde::Serialize),
13 serde(rename_all = "lowercase"),
14 strum_discriminants(derive(serde::Deserialize, serde::Serialize))
15)]
16#[repr(C)]
17#[strum(serialize_all = "lowercase")]
18#[strum_discriminants(
19 derive(
20 Display,
21 EnumCount,
22 EnumIs,
23 EnumIter,
24 EnumString,
25 Hash,
26 Ord,
27 PartialOrd,
28 VariantNames
29 ),
30 name(TensorType)
31)]
32pub enum Tensors<T> {
33 Scalar(T),
34 Tensor(TensorBase<T>),
35}
36
37impl<T> Tensors<T> {
38 pub fn scalar(scalar: T) -> Self {
39 Self::Scalar(scalar)
40 }
41
42 pub fn tensor(tensor: TensorBase<T>) -> Self {
43 Self::Tensor(tensor)
44 }
45
46 pub fn rank(&self) -> Rank {
47 match self {
48 Self::Tensor(tensor) => tensor.rank(),
49 _ => Rank::scalar(),
50 }
51 }
52}
53
54impl<T> From<TensorBase<T>> for Tensors<T>
55where
56 T: Clone,
57{
58 fn from(tensor: TensorBase<T>) -> Self {
59 if tensor.rank().is_scalar() {
60 Self::Scalar(unsafe { tensor.into_scalar() })
61 } else {
62 Self::Tensor(tensor)
63 }
64 }
65}
66
67#[cfg(test)]
68mod tests {
69 use super::*;
70
71 #[test]
72 fn test_tensor_type() {
73 let shape = (2, 3);
74 let tensor = TensorBase::<f64>::ones(shape);
75 let item = Tensors::tensor(tensor);
76
77 assert_eq!(item.rank(), Rank::from(2));
78 }
79}