hotg_rune_core/
shape.rs

1use alloc::{
2    borrow::Cow,
3    string::{String, ToString},
4    vec::Vec,
5};
6use core::{
7    fmt::{self, Formatter, Display},
8    num::ParseIntError,
9    str::FromStr,
10};
11use crate::element_type::ElementType;
12
/// A tensor's shape: the type of its elements plus the length of each of its
/// dimensions.
///
/// The dimensions are stored in a `Cow`, so a `Shape` can either borrow them
/// (e.g. from a `'static` slice) or own them; `Shape::to_owned()` produces a
/// `Shape<'static>` that owns its dimensions.
#[derive(
    Debug, Clone, PartialEq, Hash, serde::Serialize, serde::Deserialize,
)]
pub struct Shape<'a> {
    /// The type of every element in the tensor.
    element_type: ElementType,
    /// The length of each dimension, in the order they are displayed/parsed.
    dimensions: Cow<'a, [usize]>,
}
21
22impl<'a> Shape<'a> {
23    pub fn new(
24        element_type: ElementType,
25        dimensions: impl Into<Cow<'a, [usize]>>,
26    ) -> Self {
27        Shape {
28            element_type,
29            dimensions: dimensions.into(),
30        }
31    }
32
33    pub fn element_type(&self) -> ElementType { self.element_type }
34
35    pub fn dimensions(&self) -> &[usize] { &self.dimensions }
36
37    /// The number of bytes this tensor would take up, if it has a fized size.
38    pub fn size(&self) -> Option<usize> {
39        let element_size = self.element_type.size_of()?;
40
41        Some(self.dimensions.iter().product::<usize>() * element_size)
42    }
43
44    pub fn to_owned(&self) -> Shape<'static> {
45        let Shape {
46            element_type,
47            dimensions,
48        } = self;
49
50        Shape::new(*element_type, dimensions.clone().into_owned())
51    }
52}
53
54impl<'a> Display for Shape<'a> {
55    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
56        let Shape {
57            element_type,
58            dimensions,
59        } = self;
60        write!(f, "{}[", element_type)?;
61
62        for (i, dim) in dimensions.iter().enumerate() {
63            if i > 0 {
64                write!(f, ", ")?;
65            }
66
67            write!(f, "{}", dim)?;
68        }
69
70        write!(f, "]")?;
71        Ok(())
72    }
73}
74
75impl FromStr for Shape<'static> {
76    type Err = FormatError;
77
78    fn from_str(s: &str) -> Result<Self, Self::Err> {
79        let opening_bracket = s.find('[').ok_or(FormatError::Malformed)?;
80        let element_type = s[..opening_bracket].trim();
81        let ty = element_type.parse().map_err(|_| {
82            FormatError::UnknownElementType {
83                found: element_type.to_string(),
84            }
85        })?;
86
87        let closing_bracket = s.rfind(']').ok_or(FormatError::Malformed)?;
88
89        let between_brackets = &s[opening_bracket + 1..closing_bracket];
90
91        let mut dimensions = Vec::new();
92
93        for word in between_brackets.split(',') {
94            let word = word.trim();
95            let dimension = word.parse::<usize>().map_err(|e| {
96                FormatError::BadDimension {
97                    found: word.to_string(),
98                    reason: e,
99                }
100            })?;
101            dimensions.push(dimension);
102        }
103
104        Ok(Shape {
105            element_type: ty,
106            dimensions: dimensions.into(),
107        })
108    }
109}
110
/// The error returned when parsing a `Shape` from a string fails.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FormatError {
    /// The text doesn't have the overall `<element type>[<dimensions>]`
    /// structure.
    Malformed,
    /// The text before the `[` isn't a recognised element type.
    UnknownElementType {
        /// The (trimmed) element type text that couldn't be parsed.
        found: String,
    },
    /// One of the comma-separated dimensions isn't a valid `usize`.
    BadDimension {
        /// The (trimmed) dimension text that couldn't be parsed.
        found: String,
        /// Why the dimension couldn't be parsed as an integer.
        reason: ParseIntError,
    },
}
122
123impl Display for FormatError {
124    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
125        match self {
126            FormatError::Malformed => write!(f, "Malformed shape"),
127            FormatError::UnknownElementType { found } => {
128                write!(f, "Couldn't recognise the \"{}\" element type", found)
129            },
130            FormatError::BadDimension { found, .. } => {
131                write!(f, "\"{}\" isn't a valid dimension", found)
132            },
133        }
134    }
135}
136
#[cfg(feature = "std")]
impl std::error::Error for FormatError {
    /// Only `BadDimension` wraps an underlying error (the integer parse
    /// failure); every other variant has no source.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        if let FormatError::BadDimension { reason, .. } = self {
            Some(reason)
        } else {
            None
        }
    }
}
146
#[cfg(test)]
mod tests {
    use super::*;
    use std::prelude::v1::*;

    /// Shapes paired with their canonical text form; shared by the
    /// formatting and parsing tests so the two directions stay in sync.
    const SHAPES: &[(Shape, &str)] = &[
        (
            Shape {
                element_type: ElementType::F32,
                dimensions: Cow::Borrowed(&[1, 2, 3]),
            },
            "f32[1, 2, 3]",
        ),
        (
            Shape {
                element_type: ElementType::U8,
                dimensions: Cow::Borrowed(&[42]),
            },
            "u8[42]",
        ),
    ];

    /// Every shape should render to its canonical text.
    #[test]
    fn shape_format() {
        for (shape, expected) in SHAPES {
            assert_eq!(shape.to_string(), *expected);
        }
    }

    /// Every canonical text should parse back to its shape.
    #[test]
    fn parse() {
        for (expected, text) in SHAPES {
            let parsed: Shape = text.parse().unwrap();
            assert_eq!(&parsed, expected);
        }
    }
}