hpt_common/error/
shape.rs

1use std::panic::Location;
2
3use thiserror::Error;
4
5use crate::{shape::shape::Shape, strides::strides::Strides};
6
7/// Errors related to tensor shapes and dimensions
8#[derive(Debug, Error)]
9pub enum ShapeError {
10    /// Error that occurs when the size of two tensors does not match
11    #[error("Size mismatch: expected {expected}, got {actual} at {location}")]
12    SizeMismatch {
13        /// Expected size
14        expected: i64,
15        /// Actual size
16        actual: i64,
17        /// Location where the error occurred
18        location: &'static Location<'static>,
19    },
20
21    /// Error that occurs when the dimension of a tensor is invalid
22    #[error("Invalid dimension: {message} at {location}")]
23    InvalidDimension {
24        /// Message describing the invalid dimension
25        message: String,
26        /// Location where the error occurred
27        location: &'static Location<'static>,
28    },
29
30    /// Error that occurs when broadcasting fails
31    #[error("Broadcasting error: {message} at {location}")]
32    BroadcastError {
33        /// Message describing the broadcasting failure
34        message: String,
35        /// Location where the error occurred
36        location: &'static Location<'static>,
37    },
38
39    /// Error that occurs when the shape of two tensors does not match for matrix multiplication
40    #[error("Matrix multiplication shape mismatch: lhs shape {lhs:?}, rhs shape {rhs:?}, expected rhs shape [{expected}, N] at {location}")]
41    MatmulMismatch {
42        /// Left-hand side shape
43        lhs: Shape,
44        /// Right-hand side shape
45        rhs: Shape,
46        /// Expected shape for the right-hand side
47        expected: i64,
48        /// Location where the error occurred
49        location: &'static Location<'static>,
50    },
51
52    /// Error that occurs when the dimension of two tensors does not match
53    #[error("Dimension mismatch: expected {expected}, got {actual} at {location}")]
54    DimMismatch {
55        /// Expected dimension
56        expected: usize,
57        /// Actual dimension
58        actual: usize,
59        /// Location where the error occurred
60        location: &'static Location<'static>,
61    },
62
63    /// Error that occurs when the dimension is out of range
64    #[error("Dimension out of range: expected in {expected:?}, got {actual} at {location}")]
65    DimOutOfRange {
66        /// Expected range
67        expected: std::ops::Range<i64>,
68        /// Actual dimension
69        actual: i64,
70        /// Location where the error occurred
71        location: &'static Location<'static>,
72    },
73
74    /// Error that occurs when geomspace parameters are invalid
75    #[error("Geomspace error: start {start} and end {end} must have the same sign at {location}")]
76    GeomSpaceError {
77        /// Start value
78        start: f64,
79        /// End value
80        end: f64,
81        /// Location where error occurred
82        location: &'static Location<'static>,
83    },
84
85    /// Error that occurs when concat dimensions don't match
86    #[error("Concat dimension mismatch: expected {expected} but got {actual} at {location}")]
87    ConcatDimMismatch {
88        /// Expected dimension
89        expected: usize,
90        /// Actual dimension
91        actual: usize,
92        /// Location where error occurred
93        location: &'static Location<'static>,
94    },
95
96    /// Error that occurs when the number of dimensions of a tensor is less than the expected value
97    #[error("Expected greater than {expected}, got {actual} at {location}")]
98    NdimNotEnough {
99        /// Message describing the error
100        message: String,
101        /// Expected dimension
102        expected: usize,
103        /// Actual dimension
104        actual: usize,
105        /// Location where error occurred
106        location: &'static Location<'static>,
107    },
108
109    /// Error that occurs when the axis is not 1
110    #[error("Squeeze error: axis {axis} is not 1, shape {shape}, at {location}")]
111    SqueezeError {
112        /// Axis that is not 1
113        axis: usize,
114        /// Shape of the tensor
115        shape: Shape,
116        /// Location where error occurred
117        location: &'static Location<'static>,
118    },
119
120    /// Error that occurs when the tensor is not contiguous
121    #[error("{message}Tensor is not contiguous, got shape {shape:?}, strides {strides:?}, at {location}")]
122    ContiguousError {
123        /// message
124        message: String,
125        /// Shape of the tensor
126        shape: Shape,
127        /// Strides of the tensor
128        strides: Strides,
129        /// Location where error occurred
130        location: &'static Location<'static>,
131    },
132
133    /// Error that occurs when the input shape is invalid for conv2d
134    #[error("Conv error: {message} at {location}")]
135    ConvError {
136        /// Message describing the invalid input shape
137        message: String,
138        /// Location where error occurred
139        location: &'static Location<'static>,
140    },
141
142    /// Error that occurs when the topk operation is invalid
143    #[error("Topk error: {message} at {location}")]
144    TopkError {
145        /// Message describing the invalid topk operation
146        message: String,
147        /// Location where error occurred
148        location: &'static Location<'static>,
149    },
150
151    /// Error that occurs when the inplace reshape is invalid
152    #[error("Inplace reshape error: {message} at {location}")]
153    InplaceReshapeError {
154        /// Message describing the invalid inplace reshape
155        message: String,
156        /// Location where error occurred
157        location: &'static Location<'static>,
158    },
159
160    /// Error that occurs when the dimention to expand is not 1
161    #[error("Expand error: dimention {old_dim} is not 1, at {location}")]
162    ExpandError {
163        /// Old dimention
164        old_dim: i64,
165        /// Location where error occurred
166        location: &'static Location<'static>,
167    },
168
169    /// Error that occurs when the shape is invalid
170    #[error("Invalid shape: {message} at {location}")]
171    InvalidShape {
172        /// Message describing the invalid shape
173        message: String,
174        /// Location where error occurred
175        location: &'static Location<'static>,
176    },
177}
178
179impl ShapeError {
180    /// Check if the shapes of two tensors match for matrix multiplication
181    #[track_caller]
182    pub fn check_matmul(lhs: &Shape, rhs: &Shape) -> Result<(), Self> {
183        let lhs_last = *lhs.last().expect("lhs shape is empty");
184        let rhs_last_sec = rhs[rhs.len() - 2];
185        if lhs_last != rhs_last_sec {
186            return Err(Self::MatmulMismatch {
187                lhs: lhs.clone(),
188                rhs: rhs.clone(),
189                expected: lhs_last,
190                location: Location::caller(),
191            });
192        }
193        Ok(())
194    }
195
196    /// Check if the dimensions of two tensors match
197    #[track_caller]
198    pub fn check_dim(expected: usize, actual: usize) -> Result<(), Self> {
199        if expected != actual {
200            return Err(Self::DimMismatch {
201                expected,
202                actual,
203                location: Location::caller(),
204            });
205        }
206        Ok(())
207    }
208
209    /// Check if the number of dimensions of a tensor is greater than the expected value
210    #[track_caller]
211    pub fn check_ndim_enough(msg: String, expected: usize, actual: usize) -> Result<(), Self> {
212        if expected > actual {
213            return Err(Self::NdimNotEnough {
214                message: msg,
215                expected,
216                actual,
217                location: Location::caller(),
218            });
219        }
220        Ok(())
221    }
222
223    /// Check if the tensor is contiguous
224    #[track_caller]
225    pub fn check_contiguous(
226        msg: String,
227        layout: &crate::layout::layout::Layout,
228    ) -> Result<(), Self> {
229        if !layout.is_contiguous() {
230            return Err(Self::ContiguousError {
231                message: msg,
232                shape: layout.shape().clone(),
233                strides: layout.strides().clone(),
234                location: Location::caller(),
235            });
236        }
237        Ok(())
238    }
239
240    /// Check if the size of two tensors match
241    #[track_caller]
242    pub fn check_size_match(expected: i64, actual: i64) -> Result<(), Self> {
243        if expected != actual {
244            return Err(Self::SizeMismatch {
245                expected,
246                actual,
247                location: Location::caller(),
248            });
249        }
250        Ok(())
251    }
252
253    /// Check if the output layout is valid for computation with inplace operation
254    #[track_caller]
255    pub fn check_inplace_out_layout_valid(
256        out_shape: &Shape,
257        inplace_layout: &crate::layout::layout::Layout,
258    ) -> Result<(), Self> {
259        Self::check_size_match(out_shape.size(), inplace_layout.size())?;
260        Self::check_contiguous(
261            "Method with out Tensor requires out Tensor to be contiguous. ".to_string(),
262            inplace_layout,
263        )?;
264        Ok(())
265    }
266
267    /// Check if the index is out of range
268    #[track_caller]
269    pub fn check_index_out_of_range(index: i64, dim: i64) -> Result<(), Self> {
270        if index >= dim || index < 0 {
271            return Err(Self::DimOutOfRange {
272                expected: 0..dim,
273                actual: index,
274                location: Location::caller(),
275            });
276        }
277
278        Ok(())
279    }
280}