Skip to main content

cubecl_zspace/
errors.rs

1//! # Common Burn Errors
2
3use alloc::string::String;
4use core::{
5    error::Error,
6    fmt::{Display, Formatter},
7    ops::Range,
8};
9
10use crate::{Shape, Strides};
11
/// Describes the kind of an index.
///
/// Distinguishes an index that selects an element from an index that selects
/// a dimension.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum IndexKind {
    /// An index selecting an element within a dimension.
    Element,

    /// An index selecting a dimension.
    Dimension,
}

impl IndexKind {
    /// Returns the lowercase display label for this kind (`"element"` or
    /// `"dimension"`), suitable for embedding in error messages.
    pub fn name(&self) -> &'static str {
        match *self {
            Self::Dimension => "dimension",
            Self::Element => "element",
        }
    }
}

32/// Access Bounds Error.
33#[derive(Debug, PartialEq, Eq, Clone, Hash)]
34pub enum BoundsError {
35    /// Generic bounds error.
36    Generic(String),
37
38    /// Index out of bounds.
39    Index {
40        /// The kind of index that was out of bounds.
41        kind: IndexKind,
42
43        /// The index that was out of bounds.
44        index: isize,
45
46        /// The range of valid indices.
47        bounds: Range<isize>,
48    },
49}
50
51impl BoundsError {
52    /// Create a new index error.
53    pub fn index(kind: IndexKind, index: isize, bounds: Range<isize>) -> Self {
54        Self::Index {
55            kind,
56            index,
57            bounds,
58        }
59    }
60}
61
62impl core::fmt::Display for BoundsError {
63    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
64        match self {
65            Self::Generic(msg) => write!(f, "BoundsError: {}", msg),
66            Self::Index {
67                kind,
68                index,
69                bounds: range,
70            } => write!(
71                f,
72                "BoundsError: {} {} out of bounds: {:?}",
73                kind.name(),
74                index,
75                range
76            ),
77        }
78    }
79}
80
81impl core::error::Error for BoundsError {}
82
/// Common Expression Error.
///
/// Every variant carries a human-readable `message` together with the
/// `source` expression text it refers to.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExpressionError {
    /// Parse Error.
    ParseError {
        /// Human-readable description of the problem.
        message: String,
        /// The expression text the problem refers to.
        source: String,
    },

    /// Invalid Expression.
    InvalidExpression {
        /// Human-readable description of the problem.
        message: String,
        /// The expression text the problem refers to.
        source: String,
    },
}

impl core::fmt::Display for ExpressionError {
    // Rendered as `ExpressionError: <Variant>: <message> (<source>)`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            ExpressionError::ParseError {
                message: msg,
                source: src,
            } => write!(f, "ExpressionError: ParseError: {} ({})", msg, src),
            ExpressionError::InvalidExpression {
                message: msg,
                source: src,
            } => write!(f, "ExpressionError: InvalidExpression: {} ({})", msg, src),
        }
    }
}

impl core::error::Error for ExpressionError {}

impl ExpressionError {
    /// Builds an [`ExpressionError::ParseError`].
    ///
    /// # Parameters
    ///
    /// - `message`: anything convertible into a `String`; a human-readable
    ///   description of why parsing failed.
    /// - `source`: anything convertible into a `String`; typically the input
    ///   that triggered the failure.
    pub fn parse_error(message: impl Into<String>, source: impl Into<String>) -> Self {
        let (message, source) = (message.into(), source.into());
        ExpressionError::ParseError { message, source }
    }

    /// Builds an [`ExpressionError::InvalidExpression`].
    ///
    /// # Parameters
    ///
    /// - `message`: anything convertible into a `String`; describes what is
    ///   wrong with the expression.
    /// - `source`: anything convertible into a `String`; the source or context
    ///   in which the invalid expression occurred.
    pub fn invalid_expression(message: impl Into<String>, source: impl Into<String>) -> Self {
        let (message, source) = (message.into(), source.into());
        ExpressionError::InvalidExpression { message, source }
    }
}

154/// Collected shape/stride record.
155///
156/// As this is used for error messages, there is no expectation that this is valid,
157/// or that the ranks match.
158#[derive(Debug, Clone, PartialEq)]
159pub struct StrideRecord {
160    pub shape: Shape,
161    pub strides: Strides,
162}
163
164impl StrideRecord {
165    /// Create a new [`StrideRecord`] from a slice of usize strides.
166    pub fn from_usize_strides(shape: &[usize], strides: &[usize]) -> StrideRecord {
167        StrideRecord {
168            shape: shape.into(),
169            strides: strides.iter().map(|s| *s as isize).collect(),
170        }
171    }
172
173    /// Create a new [`StrideRecord`] from a slice of isize strides.
174    pub fn from_isize_strides(shape: &[usize], strides: &[isize]) -> StrideRecord {
175        StrideRecord {
176            shape: shape.into(),
177            strides: strides.into(),
178        }
179    }
180}
181
182/// Error describing striding issues.
183#[derive(Debug, Clone, PartialEq)]
184pub enum StrideError {
185    /// The ranks of the shape and strides do not match.
186    MalformedRanks { record: StrideRecord },
187
188    /// This is an unsupported rank.
189    UnsupportedRank { rank: usize, record: StrideRecord },
190
191    /// The strides violate a constraint.
192    Invalid {
193        message: String,
194        record: StrideRecord,
195    },
196}
197
198impl Display for StrideError {
199    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
200        match self {
201            StrideError::MalformedRanks { record } => write!(f, "Malformed strides: {:?}", record),
202            StrideError::UnsupportedRank { rank, record } => {
203                write!(f, "Unsupported rank {}: {:?}", rank, record)
204            }
205            StrideError::Invalid { message, record } => {
206                write!(f, "Invalid strides: {}: {:?}", message, record)
207            }
208        }
209    }
210}
211
212impl Error for StrideError {}
213
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::format;
    use alloc::string::ToString;

    #[test]
    fn test_index_kind_name() {
        assert_eq!(IndexKind::Element.name(), "element");
        assert_eq!(IndexKind::Dimension.name(), "dimension");
    }

    #[test]
    fn test_bounds_error_display() {
        assert_eq!(
            format!("{}", BoundsError::Generic("test".to_string())),
            "BoundsError: test"
        );
        assert_eq!(
            format!(
                "{}",
                BoundsError::Index {
                    kind: IndexKind::Element,
                    index: 1,
                    bounds: 0..2
                }
            ),
            "BoundsError: element 1 out of bounds: 0..2"
        );
    }

    #[test]
    fn test_bounds_error_index_constructor() {
        // The constructor must produce the same value as the struct-variant literal,
        // and negative indices must render correctly.
        let err = BoundsError::index(IndexKind::Dimension, -1, 0..3);
        assert_eq!(
            err,
            BoundsError::Index {
                kind: IndexKind::Dimension,
                index: -1,
                bounds: 0..3
            }
        );
        assert_eq!(
            format!("{}", err),
            "BoundsError: dimension -1 out of bounds: 0..3"
        );
    }

    #[test]
    fn test_parse_error() {
        let err = ExpressionError::parse_error("test", "source");
        assert_eq!(
            format!("{:?}", err),
            "ParseError { message: \"test\", source: \"source\" }"
        );
        assert_eq!(
            format!("{}", err),
            "ExpressionError: ParseError: test (source)"
        );
    }

    #[test]
    fn test_invalid_expression() {
        let err = ExpressionError::invalid_expression("test", "source");
        assert_eq!(
            format!("{:?}", err),
            "InvalidExpression { message: \"test\", source: \"source\" }"
        );
        assert_eq!(
            format!("{}", err),
            "ExpressionError: InvalidExpression: test (source)"
        );
    }
}