hpt_common/error/
shape.rs

1use std::panic::Location;
2
3use thiserror::Error;
4
5use crate::{shape::shape::Shape, strides::strides::Strides};
6
7/// Errors related to tensor shapes and dimensions
8#[derive(Debug, Error)]
9pub enum ShapeError {
10    /// Error that occurs when there is size problem
11    #[error("InvalidSize: {message} at {location}")]
12    InvalidSize {
13        /// Message describing the size mismatch
14        message: String,
15        /// Location where the error occurred
16        location: &'static Location<'static>,
17    },
18
19    /// Error that occurs when the dimension of a tensor is invalid
20    #[error("Invalid dimension: {message} at {location}")]
21    InvalidDimension {
22        /// Message describing the invalid dimension
23        message: String,
24        /// Location where the error occurred
25        location: &'static Location<'static>,
26    },
27
28    /// Error that occurs when broadcasting fails
29    #[error("Broadcasting error: {message} at {location}")]
30    BroadcastError {
31        /// Message describing the broadcasting failure
32        message: String,
33        /// Location where the error occurred
34        location: &'static Location<'static>,
35    },
36
37    /// Error that occurs when the shape of two tensors does not match for matrix multiplication
38    #[error("Matrix multiplication shape mismatch: lhs shape {lhs:?}, rhs shape {rhs:?}, expected rhs shape [{expected}, N] at {location}")]
39    MatmulMismatch {
40        /// Left-hand side shape
41        lhs: Shape,
42        /// Right-hand side shape
43        rhs: Shape,
44        /// Expected shape for the right-hand side
45        expected: i64,
46        /// Location where the error occurred
47        location: &'static Location<'static>,
48    },
49
50    /// Error that occurs when the dimension of two tensors does not match
51    #[error("Dimension mismatch: expected {expected}, got {actual} at {location}")]
52    DimMismatch {
53        /// Expected dimension
54        expected: usize,
55        /// Actual dimension
56        actual: usize,
57        /// Location where the error occurred
58        location: &'static Location<'static>,
59    },
60
61    /// Error that occurs when the dimension is out of range
62    #[error("Dimension out of range: expected in {expected:?}, got {actual} at {location}")]
63    DimOutOfRange {
64        /// Expected range
65        expected: std::ops::Range<i64>,
66        /// Actual dimension
67        actual: i64,
68        /// Location where the error occurred
69        location: &'static Location<'static>,
70    },
71
72    /// Error that occurs when geomspace parameters are invalid
73    #[error("Geomspace error: start {start} and end {end} must have the same sign at {location}")]
74    GeomSpaceError {
75        /// Start value
76        start: f64,
77        /// End value
78        end: f64,
79        /// Location where error occurred
80        location: &'static Location<'static>,
81    },
82
83    /// Error that occurs when concat dimensions don't match
84    #[error("Concat dimension mismatch: expected {expected} but got {actual} at {location}")]
85    ConcatDimMismatch {
86        /// Expected dimension
87        expected: usize,
88        /// Actual dimension
89        actual: usize,
90        /// Location where error occurred
91        location: &'static Location<'static>,
92    },
93
94    /// Error that occurs when the number of dimensions of a tensor is less than the expected value
95    #[error("Expected greater than {expected}, got {actual} at {location}")]
96    NdimNotEnough {
97        /// Message describing the error
98        message: String,
99        /// Expected dimension
100        expected: usize,
101        /// Actual dimension
102        actual: usize,
103        /// Location where error occurred
104        location: &'static Location<'static>,
105    },
106
107    /// Error that occurs when the axis is not 1
108    #[error("Squeeze error: axis {axis} is not 1, shape {shape}, at {location}")]
109    SqueezeError {
110        /// Axis that is not 1
111        axis: usize,
112        /// Shape of the tensor
113        shape: Shape,
114        /// Location where error occurred
115        location: &'static Location<'static>,
116    },
117
118    /// Error that occurs when the tensor is not contiguous
119    #[error("{message}Tensor is not contiguous, got shape {shape:?}, strides {strides:?}, at {location}")]
120    ContiguousError {
121        /// message
122        message: String,
123        /// Shape of the tensor
124        shape: Shape,
125        /// Strides of the tensor
126        strides: Strides,
127        /// Location where error occurred
128        location: &'static Location<'static>,
129    },
130
131    /// Error that occurs when the input shape is invalid for conv2d
132    #[error("Conv error: {message} at {location}")]
133    ConvError {
134        /// Message describing the invalid input shape
135        message: String,
136        /// Location where error occurred
137        location: &'static Location<'static>,
138    },
139
140    /// Error that occurs when the topk operation is invalid
141    #[error("Topk error: {message} at {location}")]
142    TopkError {
143        /// Message describing the invalid topk operation
144        message: String,
145        /// Location where error occurred
146        location: &'static Location<'static>,
147    },
148
149    /// Error that occurs when the inplace reshape is invalid
150    #[error("Inplace reshape error: {message} at {location}")]
151    InplaceReshapeError {
152        /// Message describing the invalid inplace reshape
153        message: String,
154        /// Location where error occurred
155        location: &'static Location<'static>,
156    },
157
158    /// Error that occurs when the dimention to expand is not 1
159    #[error("Expand error: dimention {old_dim} is not 1, at {location}")]
160    ExpandError {
161        /// Old dimention
162        old_dim: i64,
163        /// Location where error occurred
164        location: &'static Location<'static>,
165    },
166
167    /// Error that occurs when the shape is invalid
168    #[error("Invalid shape: {message} at {location}")]
169    InvalidShape {
170        /// Message describing the invalid shape
171        message: String,
172        /// Location where error occurred
173        location: &'static Location<'static>,
174    },
175}
176
177impl ShapeError {
178    /// Check if the shapes of two tensors match for matrix multiplication
179    #[track_caller]
180    pub fn check_matmul(lhs: &Shape, rhs: &Shape) -> Result<(), Self> {
181        let lhs_last = *lhs.last().expect("lhs shape is empty");
182        let rhs_last_sec = rhs[rhs.len() - 2];
183        if lhs_last != rhs_last_sec {
184            return Err(Self::MatmulMismatch {
185                lhs: lhs.clone(),
186                rhs: rhs.clone(),
187                expected: lhs_last,
188                location: Location::caller(),
189            });
190        }
191        Ok(())
192    }
193
194    /// Check if the dimensions of two tensors match
195    #[track_caller]
196    pub fn check_dim(expected: usize, actual: usize) -> Result<(), Self> {
197        if expected != actual {
198            return Err(Self::DimMismatch {
199                expected,
200                actual,
201                location: Location::caller(),
202            });
203        }
204        Ok(())
205    }
206
207    /// Check if the number of dimensions of a tensor is greater than the expected value
208    #[track_caller]
209    pub fn check_ndim_enough(msg: String, expected: usize, actual: usize) -> Result<(), Self> {
210        if expected > actual {
211            return Err(Self::NdimNotEnough {
212                message: msg,
213                expected,
214                actual,
215                location: Location::caller(),
216            });
217        }
218        Ok(())
219    }
220
221    /// Check if the tensor is contiguous
222    #[track_caller]
223    pub fn check_contiguous(
224        msg: String,
225        layout: &crate::layout::layout::Layout,
226    ) -> Result<(), Self> {
227        if !layout.is_contiguous() {
228            return Err(Self::ContiguousError {
229                message: msg,
230                shape: layout.shape().clone(),
231                strides: layout.strides().clone(),
232                location: Location::caller(),
233            });
234        }
235        Ok(())
236    }
237
238    /// Check if the size of two tensors match
239    #[track_caller]
240    pub fn check_size_match(expected: i64, actual: i64) -> Result<(), Self> {
241        if expected != actual {
242            return Err(Self::InvalidSize {
243                message: format!("Size mismatch: expected {}, got {}", expected, actual),
244                location: Location::caller(),
245            });
246        }
247        Ok(())
248    }
249
250    /// Check if the size of two tensors match
251    #[track_caller]
252    pub fn check_size_gt(expected: i64, actual: i64) -> Result<(), Self> {
253        if expected > actual {
254            return Err(Self::InvalidSize {
255                message: format!("expected size greater than {}, got {}", expected, actual),
256                location: Location::caller(),
257            });
258        }
259        Ok(())
260    }
261
262    /// Check if the output layout is valid for computation with inplace operation
263    #[track_caller]
264    pub fn check_inplace_out_layout_valid(
265        out_shape: &Shape,
266        inplace_layout: &crate::layout::layout::Layout,
267    ) -> Result<(), Self> {
268        Self::check_size_match(out_shape.size(), inplace_layout.size())?;
269        Self::check_contiguous(
270            "Method with out Tensor requires out Tensor to be contiguous. ".to_string(),
271            inplace_layout,
272        )?;
273        Ok(())
274    }
275
276    /// Check if the index is out of range
277    #[track_caller]
278    pub fn check_index_out_of_range(index: i64, dim: i64) -> Result<(), Self> {
279        if index >= dim || index < 0 {
280            return Err(Self::DimOutOfRange {
281                expected: 0..dim,
282                actual: index,
283                location: Location::caller(),
284            });
285        }
286
287        Ok(())
288    }
289}