1use alloc::string::String;
4use core::{
5 error::Error,
6 fmt::{Display, Formatter},
7 ops::Range,
8};
9
10use crate::{Shape, Strides};
11
/// Distinguishes what an out-of-bounds index refers to in a [`BoundsError::Index`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum IndexKind {
    /// The index selects an element.
    Element,

    /// The index selects a dimension.
    Dimension,
}
21
22impl IndexKind {
23 pub fn name(&self) -> &'static str {
25 match self {
26 IndexKind::Element => "element",
27 IndexKind::Dimension => "dimension",
28 }
29 }
30}
31
32#[derive(Debug, PartialEq, Eq, Clone, Hash)]
34pub enum BoundsError {
35 Generic(String),
37
38 Index {
40 kind: IndexKind,
42
43 index: isize,
45
46 bounds: Range<isize>,
48 },
49}
50
51impl BoundsError {
52 pub fn index(kind: IndexKind, index: isize, bounds: Range<isize>) -> Self {
54 Self::Index {
55 kind,
56 index,
57 bounds,
58 }
59 }
60}
61
62impl core::fmt::Display for BoundsError {
63 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
64 match self {
65 Self::Generic(msg) => write!(f, "BoundsError: {}", msg),
66 Self::Index {
67 kind,
68 index,
69 bounds: range,
70 } => write!(
71 f,
72 "BoundsError: {} {} out of bounds: {:?}",
73 kind.name(),
74 index,
75 range
76 ),
77 }
78 }
79}
80
81impl core::error::Error for BoundsError {}
82
/// Error raised while handling an expression string.
///
/// Derives `Hash` (like `BoundsError`), so it can be used in hash-based collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ExpressionError {
    /// The expression could not be parsed.
    ParseError {
        /// Human-readable description of the failure.
        message: String,

        /// The text the error refers to (presumably the offending expression).
        source: String,
    },

    /// The expression was rejected as invalid.
    InvalidExpression {
        /// Human-readable description of the failure.
        message: String,

        /// The text the error refers to (presumably the offending expression).
        source: String,
    },
}
102
103impl core::fmt::Display for ExpressionError {
104 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
105 match self {
106 Self::ParseError { message, source } => {
107 write!(f, "ExpressionError: ParseError: {} ({})", message, source)
108 }
109 Self::InvalidExpression { message, source } => write!(
110 f,
111 "ExpressionError: InvalidExpression: {} ({})",
112 message, source
113 ),
114 }
115 }
116}
117
118impl core::error::Error for ExpressionError {}
119
120impl ExpressionError {
121 pub fn parse_error(message: impl Into<String>, source: impl Into<String>) -> Self {
133 Self::ParseError {
134 message: message.into(),
135 source: source.into(),
136 }
137 }
138
139 pub fn invalid_expression(message: impl Into<String>, source: impl Into<String>) -> Self {
147 Self::InvalidExpression {
148 message: message.into(),
149 source: source.into(),
150 }
151 }
152}
153
154#[derive(Debug, Clone, PartialEq)]
159pub struct StrideRecord {
160 pub shape: Shape,
161 pub strides: Strides,
162}
163
164impl StrideRecord {
165 pub fn from_usize_strides(shape: &[usize], strides: &[usize]) -> StrideRecord {
167 StrideRecord {
168 shape: shape.into(),
169 strides: strides.iter().map(|s| *s as isize).collect(),
170 }
171 }
172
173 pub fn from_isize_strides(shape: &[usize], strides: &[isize]) -> StrideRecord {
175 StrideRecord {
176 shape: shape.into(),
177 strides: strides.into(),
178 }
179 }
180}
181
182#[derive(Debug, Clone, PartialEq)]
184pub enum StrideError {
185 MalformedRanks { record: StrideRecord },
187
188 UnsupportedRank { rank: usize, record: StrideRecord },
190
191 Invalid {
193 message: String,
194 record: StrideRecord,
195 },
196}
197
198impl Display for StrideError {
199 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
200 match self {
201 StrideError::MalformedRanks { record } => write!(f, "Malformed strides: {:?}", record),
202 StrideError::UnsupportedRank { rank, record } => {
203 write!(f, "Unsupported rank {}: {:?}", rank, record)
204 }
205 StrideError::Invalid { message, record } => {
206 write!(f, "Invalid strides: {}: {:?}", message, record)
207 }
208 }
209 }
210}
211
212impl Error for StrideError {}
213
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::format;
    use alloc::string::ToString;

    #[test]
    fn test_index_kind_name() {
        assert_eq!(IndexKind::Element.name(), "element");
        assert_eq!(IndexKind::Dimension.name(), "dimension");
    }

    #[test]
    fn test_bounds_error_display() {
        assert_eq!(
            format!("{}", BoundsError::Generic("test".to_string())),
            "BoundsError: test"
        );
        assert_eq!(
            format!(
                "{}",
                BoundsError::Index {
                    kind: IndexKind::Element,
                    index: 1,
                    bounds: 0..2
                }
            ),
            "BoundsError: element 1 out of bounds: 0..2"
        );
    }

    #[test]
    fn test_bounds_error_index_constructor() {
        let err = BoundsError::index(IndexKind::Dimension, -1, 0..3);
        assert_eq!(
            err,
            BoundsError::Index {
                kind: IndexKind::Dimension,
                index: -1,
                bounds: 0..3
            }
        );
        // The Dimension kind and negative indices go through the same Display path.
        assert_eq!(
            format!("{}", err),
            "BoundsError: dimension -1 out of bounds: 0..3"
        );
        // These error types wrap no underlying cause.
        assert!(core::error::Error::source(&err).is_none());
    }

    #[test]
    fn test_parse_error() {
        let err = ExpressionError::parse_error("test", "source");
        assert_eq!(
            format!("{:?}", err),
            "ParseError { message: \"test\", source: \"source\" }"
        );
        // Display is the user-facing form and must be covered as well as Debug.
        assert_eq!(
            format!("{}", err),
            "ExpressionError: ParseError: test (source)"
        );
    }

    #[test]
    fn test_invalid_expression() {
        let err = ExpressionError::invalid_expression("test", "source");
        assert_eq!(
            format!("{:?}", err),
            "InvalidExpression { message: \"test\", source: \"source\" }"
        );
        assert_eq!(
            format!("{}", err),
            "ExpressionError: InvalidExpression: test (source)"
        );
        assert!(core::error::Error::source(&err).is_none());
    }
}