1use std::panic::Location;
2
3use thiserror::Error;
4
5use crate::{shape::shape::Shape, strides::strides::Strides};
6
7#[derive(Debug, Error)]
9pub enum ShapeError {
10 #[error("Size mismatch: expected {expected}, got {actual} at {location}")]
12 SizeMismatch {
13 expected: i64,
15 actual: i64,
17 location: &'static Location<'static>,
19 },
20
21 #[error("Invalid dimension: {message} at {location}")]
23 InvalidDimension {
24 message: String,
26 location: &'static Location<'static>,
28 },
29
30 #[error("Broadcasting error: {message} at {location}")]
32 BroadcastError {
33 message: String,
35 location: &'static Location<'static>,
37 },
38
39 #[error("Matrix multiplication shape mismatch: lhs shape {lhs:?}, rhs shape {rhs:?}, expected rhs shape [{expected}, N] at {location}")]
41 MatmulMismatch {
42 lhs: Shape,
44 rhs: Shape,
46 expected: i64,
48 location: &'static Location<'static>,
50 },
51
52 #[error("Dimension mismatch: expected {expected}, got {actual} at {location}")]
54 DimMismatch {
55 expected: usize,
57 actual: usize,
59 location: &'static Location<'static>,
61 },
62
63 #[error("Dimension out of range: expected in {expected:?}, got {actual} at {location}")]
65 DimOutOfRange {
66 expected: std::ops::Range<i64>,
68 actual: i64,
70 location: &'static Location<'static>,
72 },
73
74 #[error("Geomspace error: start {start} and end {end} must have the same sign at {location}")]
76 GeomSpaceError {
77 start: f64,
79 end: f64,
81 location: &'static Location<'static>,
83 },
84
85 #[error("Concat dimension mismatch: expected {expected} but got {actual} at {location}")]
87 ConcatDimMismatch {
88 expected: usize,
90 actual: usize,
92 location: &'static Location<'static>,
94 },
95
96 #[error("Expected greater than {expected}, got {actual} at {location}")]
98 NdimNotEnough {
99 message: String,
101 expected: usize,
103 actual: usize,
105 location: &'static Location<'static>,
107 },
108
109 #[error("Squeeze error: axis {axis} is not 1, shape {shape}, at {location}")]
111 SqueezeError {
112 axis: usize,
114 shape: Shape,
116 location: &'static Location<'static>,
118 },
119
120 #[error("{message}Tensor is not contiguous, got shape {shape:?}, strides {strides:?}, at {location}")]
122 ContiguousError {
123 message: String,
125 shape: Shape,
127 strides: Strides,
129 location: &'static Location<'static>,
131 },
132
133 #[error("Conv error: {message} at {location}")]
135 ConvError {
136 message: String,
138 location: &'static Location<'static>,
140 },
141
142 #[error("Topk error: {message} at {location}")]
144 TopkError {
145 message: String,
147 location: &'static Location<'static>,
149 },
150
151 #[error("Inplace reshape error: {message} at {location}")]
153 InplaceReshapeError {
154 message: String,
156 location: &'static Location<'static>,
158 },
159
160 #[error("Expand error: dimention {old_dim} is not 1, at {location}")]
162 ExpandError {
163 old_dim: i64,
165 location: &'static Location<'static>,
167 },
168
169 #[error("Invalid shape: {message} at {location}")]
171 InvalidShape {
172 message: String,
174 location: &'static Location<'static>,
176 },
177}
178
179impl ShapeError {
180 #[track_caller]
182 pub fn check_matmul(lhs: &Shape, rhs: &Shape) -> Result<(), Self> {
183 let lhs_last = *lhs.last().expect("lhs shape is empty");
184 let rhs_last_sec = rhs[rhs.len() - 2];
185 if lhs_last != rhs_last_sec {
186 return Err(Self::MatmulMismatch {
187 lhs: lhs.clone(),
188 rhs: rhs.clone(),
189 expected: lhs_last,
190 location: Location::caller(),
191 });
192 }
193 Ok(())
194 }
195
196 #[track_caller]
198 pub fn check_dim(expected: usize, actual: usize) -> Result<(), Self> {
199 if expected != actual {
200 return Err(Self::DimMismatch {
201 expected,
202 actual,
203 location: Location::caller(),
204 });
205 }
206 Ok(())
207 }
208
209 #[track_caller]
211 pub fn check_ndim_enough(msg: String, expected: usize, actual: usize) -> Result<(), Self> {
212 if expected > actual {
213 return Err(Self::NdimNotEnough {
214 message: msg,
215 expected,
216 actual,
217 location: Location::caller(),
218 });
219 }
220 Ok(())
221 }
222
223 #[track_caller]
225 pub fn check_contiguous(
226 msg: String,
227 layout: &crate::layout::layout::Layout,
228 ) -> Result<(), Self> {
229 if !layout.is_contiguous() {
230 return Err(Self::ContiguousError {
231 message: msg,
232 shape: layout.shape().clone(),
233 strides: layout.strides().clone(),
234 location: Location::caller(),
235 });
236 }
237 Ok(())
238 }
239
240 #[track_caller]
242 pub fn check_size_match(expected: i64, actual: i64) -> Result<(), Self> {
243 if expected != actual {
244 return Err(Self::SizeMismatch {
245 expected,
246 actual,
247 location: Location::caller(),
248 });
249 }
250 Ok(())
251 }
252
253 #[track_caller]
255 pub fn check_inplace_out_layout_valid(
256 out_shape: &Shape,
257 inplace_layout: &crate::layout::layout::Layout,
258 ) -> Result<(), Self> {
259 Self::check_size_match(out_shape.size(), inplace_layout.size())?;
260 Self::check_contiguous(
261 "Method with out Tensor requires out Tensor to be contiguous. ".to_string(),
262 inplace_layout,
263 )?;
264 Ok(())
265 }
266
267 #[track_caller]
269 pub fn check_index_out_of_range(index: i64, dim: i64) -> Result<(), Self> {
270 if index >= dim || index < 0 {
271 return Err(Self::DimOutOfRange {
272 expected: 0..dim,
273 actual: index,
274 location: Location::caller(),
275 });
276 }
277
278 Ok(())
279 }
280}