1use alloc::{
2 borrow::Cow,
3 string::{String, ToString},
4 vec::Vec,
5};
6use core::{
7 fmt::{self, Formatter, Display},
8 num::ParseIntError,
9 str::FromStr,
10};
11use crate::element_type::ElementType;
12
13#[derive(
15 Debug, Clone, PartialEq, Hash, serde::Serialize, serde::Deserialize,
16)]
17pub struct Shape<'a> {
18 element_type: ElementType,
19 dimensions: Cow<'a, [usize]>,
20}
21
22impl<'a> Shape<'a> {
23 pub fn new(
24 element_type: ElementType,
25 dimensions: impl Into<Cow<'a, [usize]>>,
26 ) -> Self {
27 Shape {
28 element_type,
29 dimensions: dimensions.into(),
30 }
31 }
32
33 pub fn element_type(&self) -> ElementType { self.element_type }
34
35 pub fn dimensions(&self) -> &[usize] { &self.dimensions }
36
37 pub fn size(&self) -> Option<usize> {
39 let element_size = self.element_type.size_of()?;
40
41 Some(self.dimensions.iter().product::<usize>() * element_size)
42 }
43
44 pub fn to_owned(&self) -> Shape<'static> {
45 let Shape {
46 element_type,
47 dimensions,
48 } = self;
49
50 Shape::new(*element_type, dimensions.clone().into_owned())
51 }
52}
53
54impl<'a> Display for Shape<'a> {
55 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
56 let Shape {
57 element_type,
58 dimensions,
59 } = self;
60 write!(f, "{}[", element_type)?;
61
62 for (i, dim) in dimensions.iter().enumerate() {
63 if i > 0 {
64 write!(f, ", ")?;
65 }
66
67 write!(f, "{}", dim)?;
68 }
69
70 write!(f, "]")?;
71 Ok(())
72 }
73}
74
75impl FromStr for Shape<'static> {
76 type Err = FormatError;
77
78 fn from_str(s: &str) -> Result<Self, Self::Err> {
79 let opening_bracket = s.find('[').ok_or(FormatError::Malformed)?;
80 let element_type = s[..opening_bracket].trim();
81 let ty = element_type.parse().map_err(|_| {
82 FormatError::UnknownElementType {
83 found: element_type.to_string(),
84 }
85 })?;
86
87 let closing_bracket = s.rfind(']').ok_or(FormatError::Malformed)?;
88
89 let between_brackets = &s[opening_bracket + 1..closing_bracket];
90
91 let mut dimensions = Vec::new();
92
93 for word in between_brackets.split(',') {
94 let word = word.trim();
95 let dimension = word.parse::<usize>().map_err(|e| {
96 FormatError::BadDimension {
97 found: word.to_string(),
98 reason: e,
99 }
100 })?;
101 dimensions.push(dimension);
102 }
103
104 Ok(Shape {
105 element_type: ty,
106 dimensions: dimensions.into(),
107 })
108 }
109}
110
/// The reasons parsing a [`Shape`] from a string can fail.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FormatError {
    /// The input wasn't structured as `<element type>[<dimensions>]`
    /// (e.g. a bracket is missing).
    Malformed,
    /// The text before `[` wasn't a recognised element type.
    UnknownElementType {
        /// The trimmed element-type text that failed to parse.
        found: String,
    },
    /// One of the comma-separated dimensions wasn't a valid `usize`.
    BadDimension {
        /// The trimmed text of the offending dimension.
        found: String,
        /// Why parsing it as a `usize` failed.
        reason: ParseIntError,
    },
}

impl Display for FormatError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match self {
            FormatError::Malformed => write!(f, "Malformed shape"),
            FormatError::UnknownElementType { found } => {
                write!(f, "Couldn't recognise the \"{}\" element type", found)
            },
            // The underlying `ParseIntError` is exposed through
            // `Error::source` (with the "std" feature) rather than repeated
            // in this message.
            FormatError::BadDimension { found, .. } => {
                write!(f, "\"{}\" isn't a valid dimension", found)
            },
        }
    }
}

#[cfg(feature = "std")]
impl std::error::Error for FormatError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            FormatError::BadDimension { reason, .. } => Some(reason),
            // Listed explicitly rather than with `_`, so that adding a variant
            // forces a decision about its source.
            FormatError::Malformed | FormatError::UnknownElementType { .. } => {
                None
            },
        }
    }
}
146
#[cfg(test)]
mod tests {
    use super::*;
    use std::prelude::v1::*;

    /// Shapes paired with their canonical textual form. Shared by the
    /// formatting and parsing tests so both directions stay in sync.
    const SHAPES: &[(Shape, &str)] = &[
        (
            Shape {
                element_type: ElementType::F32,
                dimensions: Cow::Borrowed(&[1, 2, 3]),
            },
            "f32[1, 2, 3]",
        ),
        (
            Shape {
                element_type: ElementType::U8,
                dimensions: Cow::Borrowed(&[42]),
            },
            "u8[42]",
        ),
    ];

    #[test]
    fn shape_format() {
        for (shape, should_be) in SHAPES.iter().cloned() {
            let got = shape.to_string();
            assert_eq!(got, should_be);
        }
    }

    #[test]
    fn parse() {
        for (should_be, src) in SHAPES.iter().cloned() {
            let got: Shape = src.parse().unwrap();
            assert_eq!(got, should_be);
        }
    }

    #[test]
    fn parse_tolerates_surrounding_whitespace() {
        let got: Shape = " u8 [ 7 ,8 ] ".parse().unwrap();
        assert_eq!(got.element_type(), ElementType::U8);
        assert_eq!(got.dimensions(), &[7, 8][..]);
    }

    #[test]
    fn parse_rejects_missing_brackets() {
        assert_eq!(
            "f32".parse::<Shape<'static>>(),
            Err(FormatError::Malformed)
        );
        assert_eq!(
            "f32[1, 2".parse::<Shape<'static>>(),
            Err(FormatError::Malformed)
        );
    }

    #[test]
    fn parse_reports_bad_dimension() {
        let err = "f32[1, x]".parse::<Shape<'static>>().unwrap_err();
        match err {
            FormatError::BadDimension { found, .. } => assert_eq!(found, "x"),
            other => panic!("expected BadDimension, got {:?}", other),
        }
    }

    #[test]
    fn to_owned_preserves_shape() {
        for (shape, _) in SHAPES.iter() {
            let owned: Shape<'static> = Shape::to_owned(shape);
            assert_eq!(&owned, shape);
        }
    }
}