1use std::panic::Location;
2
3use thiserror::Error;
4
5use crate::{shape::shape::Shape, strides::strides::Strides};
6
7#[derive(Debug, Error)]
9pub enum ShapeError {
10 #[error("InvalidSize: {message} at {location}")]
12 InvalidSize {
13 message: String,
15 location: &'static Location<'static>,
17 },
18
19 #[error("Invalid dimension: {message} at {location}")]
21 InvalidDimension {
22 message: String,
24 location: &'static Location<'static>,
26 },
27
28 #[error("Broadcasting error: {message} at {location}")]
30 BroadcastError {
31 message: String,
33 location: &'static Location<'static>,
35 },
36
37 #[error("Matrix multiplication shape mismatch: lhs shape {lhs:?}, rhs shape {rhs:?}, expected rhs shape [{expected}, N] at {location}")]
39 MatmulMismatch {
40 lhs: Shape,
42 rhs: Shape,
44 expected: i64,
46 location: &'static Location<'static>,
48 },
49
50 #[error("Dimension mismatch: expected {expected}, got {actual} at {location}")]
52 DimMismatch {
53 expected: usize,
55 actual: usize,
57 location: &'static Location<'static>,
59 },
60
61 #[error("Dimension out of range: expected in {expected:?}, got {actual} at {location}")]
63 DimOutOfRange {
64 expected: std::ops::Range<i64>,
66 actual: i64,
68 location: &'static Location<'static>,
70 },
71
72 #[error("Geomspace error: start {start} and end {end} must have the same sign at {location}")]
74 GeomSpaceError {
75 start: f64,
77 end: f64,
79 location: &'static Location<'static>,
81 },
82
83 #[error("Concat dimension mismatch: expected {expected} but got {actual} at {location}")]
85 ConcatDimMismatch {
86 expected: usize,
88 actual: usize,
90 location: &'static Location<'static>,
92 },
93
94 #[error("Expected greater than {expected}, got {actual} at {location}")]
96 NdimNotEnough {
97 message: String,
99 expected: usize,
101 actual: usize,
103 location: &'static Location<'static>,
105 },
106
107 #[error("Squeeze error: axis {axis} is not 1, shape {shape}, at {location}")]
109 SqueezeError {
110 axis: usize,
112 shape: Shape,
114 location: &'static Location<'static>,
116 },
117
118 #[error("{message}Tensor is not contiguous, got shape {shape:?}, strides {strides:?}, at {location}")]
120 ContiguousError {
121 message: String,
123 shape: Shape,
125 strides: Strides,
127 location: &'static Location<'static>,
129 },
130
131 #[error("Conv error: {message} at {location}")]
133 ConvError {
134 message: String,
136 location: &'static Location<'static>,
138 },
139
140 #[error("Topk error: {message} at {location}")]
142 TopkError {
143 message: String,
145 location: &'static Location<'static>,
147 },
148
149 #[error("Inplace reshape error: {message} at {location}")]
151 InplaceReshapeError {
152 message: String,
154 location: &'static Location<'static>,
156 },
157
158 #[error("Expand error: dimention {old_dim} is not 1, at {location}")]
160 ExpandError {
161 old_dim: i64,
163 location: &'static Location<'static>,
165 },
166
167 #[error("Invalid shape: {message} at {location}")]
169 InvalidShape {
170 message: String,
172 location: &'static Location<'static>,
174 },
175}
176
177impl ShapeError {
178 #[track_caller]
180 pub fn check_matmul(lhs: &Shape, rhs: &Shape) -> Result<(), Self> {
181 let lhs_last = *lhs.last().expect("lhs shape is empty");
182 let rhs_last_sec = rhs[rhs.len() - 2];
183 if lhs_last != rhs_last_sec {
184 return Err(Self::MatmulMismatch {
185 lhs: lhs.clone(),
186 rhs: rhs.clone(),
187 expected: lhs_last,
188 location: Location::caller(),
189 });
190 }
191 Ok(())
192 }
193
194 #[track_caller]
196 pub fn check_dim(expected: usize, actual: usize) -> Result<(), Self> {
197 if expected != actual {
198 return Err(Self::DimMismatch {
199 expected,
200 actual,
201 location: Location::caller(),
202 });
203 }
204 Ok(())
205 }
206
207 #[track_caller]
209 pub fn check_ndim_enough(msg: String, expected: usize, actual: usize) -> Result<(), Self> {
210 if expected > actual {
211 return Err(Self::NdimNotEnough {
212 message: msg,
213 expected,
214 actual,
215 location: Location::caller(),
216 });
217 }
218 Ok(())
219 }
220
221 #[track_caller]
223 pub fn check_contiguous(
224 msg: String,
225 layout: &crate::layout::layout::Layout,
226 ) -> Result<(), Self> {
227 if !layout.is_contiguous() {
228 return Err(Self::ContiguousError {
229 message: msg,
230 shape: layout.shape().clone(),
231 strides: layout.strides().clone(),
232 location: Location::caller(),
233 });
234 }
235 Ok(())
236 }
237
238 #[track_caller]
240 pub fn check_size_match(expected: i64, actual: i64) -> Result<(), Self> {
241 if expected != actual {
242 return Err(Self::InvalidSize {
243 message: format!("Size mismatch: expected {}, got {}", expected, actual),
244 location: Location::caller(),
245 });
246 }
247 Ok(())
248 }
249
250 #[track_caller]
252 pub fn check_size_gt(expected: i64, actual: i64) -> Result<(), Self> {
253 if expected > actual {
254 return Err(Self::InvalidSize {
255 message: format!("expected size greater than {}, got {}", expected, actual),
256 location: Location::caller(),
257 });
258 }
259 Ok(())
260 }
261
262 #[track_caller]
264 pub fn check_inplace_out_layout_valid(
265 out_shape: &Shape,
266 inplace_layout: &crate::layout::layout::Layout,
267 ) -> Result<(), Self> {
268 Self::check_size_match(out_shape.size(), inplace_layout.size())?;
269 Self::check_contiguous(
270 "Method with out Tensor requires out Tensor to be contiguous. ".to_string(),
271 inplace_layout,
272 )?;
273 Ok(())
274 }
275
276 #[track_caller]
278 pub fn check_index_out_of_range(index: i64, dim: i64) -> Result<(), Self> {
279 if index >= dim || index < 0 {
280 return Err(Self::DimOutOfRange {
281 expected: 0..dim,
282 actual: index,
283 location: Location::caller(),
284 });
285 }
286
287 Ok(())
288 }
289}