1use crate::{DType, DeviceLocation, Layout, MetalError, Shape};
3
/// Payload of the `Error::MatMulUnexpectedStriding` variant: the operand layouts and sizes
/// of a matmul whose striding was rejected.
///
/// The derived `Debug` output of this struct is exactly what that error variant displays.
#[derive(Debug, Clone)]
pub struct MatMulUnexpectedStriding {
    /// Layout of the lhs operand (per the field name).
    pub lhs_l: Layout,
    /// Layout of the rhs operand (per the field name).
    pub rhs_l: Layout,
    /// Presumably the `(b, m, n, k)` matmul sizes, as the field name suggests — TODO confirm
    /// against the call sites that build this struct.
    pub bmnk: (usize, usize, usize, usize),
    /// Static description of what was unexpected.
    pub msg: &'static str,
}
11
12impl std::fmt::Debug for Error {
13 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
14 write!(f, "{self}")
15 }
16}
17
18#[derive(thiserror::Error)]
20pub enum Error {
21 #[error("{msg}, expected: {expected:?}, got: {got:?}")]
23 UnexpectedDType {
24 msg: &'static str,
25 expected: DType,
26 got: DType,
27 },
28
29 #[error("dtype mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
30 DTypeMismatchBinaryOp {
31 lhs: DType,
32 rhs: DType,
33 op: &'static str,
34 },
35
36 #[error("unsupported dtype {0:?} for op {1}")]
37 UnsupportedDTypeForOp(DType, &'static str),
38
39 #[error("{op}: dimension index {dim} out of range for shape {shape:?}")]
41 DimOutOfRange {
42 shape: Shape,
43 dim: i32,
44 op: &'static str,
45 },
46
47 #[error("{op}: duplicate dim index {dims:?} for shape {shape:?}")]
48 DuplicateDimIndex {
49 shape: Shape,
50 dims: Vec<usize>,
51 op: &'static str,
52 },
53
54 #[error("unexpected rank, expected: {expected}, got: {got} ({shape:?})")]
56 UnexpectedNumberOfDims {
57 expected: usize,
58 got: usize,
59 shape: Shape,
60 },
61
62 #[error("{msg}, expected: {expected:?}, got: {got:?}")]
63 UnexpectedShape {
64 msg: String,
65 expected: Shape,
66 got: Shape,
67 },
68
69 #[error(
70 "Shape mismatch, got buffer of size {buffer_size} which is compatible with shape {shape:?}"
71 )]
72 ShapeMismatch { buffer_size: usize, shape: Shape },
73
74 #[error("shape mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
75 ShapeMismatchBinaryOp {
76 lhs: Shape,
77 rhs: Shape,
78 op: &'static str,
79 },
80
81 #[error("shape mismatch in cat for dim {dim}, shape for arg 1: {first_shape:?} shape for arg {n}: {nth_shape:?}")]
82 ShapeMismatchCat {
83 dim: usize,
84 first_shape: Shape,
85 n: usize,
86 nth_shape: Shape,
87 },
88
89 #[error("Cannot divide tensor of shape {shape:?} equally along dim {dim} into {n_parts}")]
90 ShapeMismatchSplit {
91 shape: Shape,
92 dim: usize,
93 n_parts: usize,
94 },
95
96 #[error("{op} can only be performed on a single dimension")]
97 OnlySingleDimension { op: &'static str, dims: Vec<usize> },
98
99 #[error("empty tensor for {op}")]
100 EmptyTensor { op: &'static str },
101
102 #[error("device mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
104 DeviceMismatchBinaryOp {
105 lhs: DeviceLocation,
106 rhs: DeviceLocation,
107 op: &'static str,
108 },
109
110 #[error("narrow invalid args {msg}: {shape:?}, dim: {dim}, start: {start}, len:{len}")]
112 NarrowInvalidArgs {
113 shape: Shape,
114 dim: usize,
115 start: usize,
116 len: usize,
117 msg: &'static str,
118 },
119
120 #[error("conv1d invalid args {msg}: inp: {inp_shape:?}, k: {k_shape:?}, pad: {padding}, stride: {stride}")]
121 Conv1dInvalidArgs {
122 inp_shape: Shape,
123 k_shape: Shape,
124 padding: usize,
125 stride: usize,
126 msg: &'static str,
127 },
128
129 #[error("{op} invalid index {index} with dim size {size}")]
130 InvalidIndex {
131 op: &'static str,
132 index: usize,
133 size: usize,
134 },
135
136 #[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
137 BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape },
138
139 #[error("cannot set variable {msg}")]
140 CannotSetVar { msg: &'static str },
141
142 #[error("{0:?}")]
144 MatMulUnexpectedStriding(Box<MatMulUnexpectedStriding>),
145
146 #[error("{op} only supports contiguous tensors")]
147 RequiresContiguous { op: &'static str },
148
149 #[error("{op} expects at least one tensor")]
150 OpRequiresAtLeastOneTensor { op: &'static str },
151
152 #[error("{op} expects at least two tensors")]
153 OpRequiresAtLeastTwoTensors { op: &'static str },
154
155 #[error("backward is not supported for {op}")]
156 BackwardNotSupported { op: &'static str },
157
158 #[error("the candle crate has not been built with cuda support")]
160 NotCompiledWithCudaSupport,
161
162 #[error("the candle crate has not been built with metal support")]
163 NotCompiledWithMetalSupport,
164
165 #[error("cannot find tensor {path}")]
166 CannotFindTensor { path: String },
167
168 #[error(transparent)]
170 Cuda(Box<dyn std::error::Error + Send + Sync>),
171
172 #[error("Metal error {0}")]
173 Metal(#[from] MetalError),
174
175 #[cfg(not(target_arch = "wasm32"))]
176 #[error(transparent)]
177 Ug(#[from] ug::Error),
178
179 #[error(transparent)]
180 TryFromIntError(#[from] core::num::TryFromIntError),
181
182 #[error("npy/npz error {0}")]
183 Npy(String),
184
185 #[error(transparent)]
187 Zip(#[from] zip::result::ZipError),
188
189 #[error(transparent)]
191 ParseInt(#[from] std::num::ParseIntError),
192
193 #[error(transparent)]
195 FromUtf8(#[from] std::string::FromUtf8Error),
196
197 #[error(transparent)]
199 Io(#[from] std::io::Error),
200
201 #[error(transparent)]
203 SafeTensor(#[from] safetensors::SafeTensorError),
204
205 #[error("unsupported safetensor dtype {0:?}")]
206 UnsupportedSafeTensorDtype(safetensors::Dtype),
207
208 #[error("{0}")]
210 Wrapped(Box<dyn std::fmt::Display + Send + Sync>),
211
212 #[error("{context}\n{inner}")]
213 Context {
214 inner: Box<Self>,
215 context: Box<dyn std::fmt::Display + Send + Sync>,
216 },
217
218 #[error("path: {path:?} {inner}")]
220 WithPath {
221 inner: Box<Self>,
222 path: std::path::PathBuf,
223 },
224
225 #[error("{inner}\n{backtrace}")]
226 WithBacktrace {
227 inner: Box<Self>,
228 backtrace: Box<std::backtrace::Backtrace>,
229 },
230
231 #[error("{0}")]
233 Msg(String),
234
235 #[error("unwrap none")]
236 UnwrapNone,
237}
238
/// Crate `Result` alias whose error type is [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
240
241impl Error {
242 pub fn wrap(err: impl std::fmt::Display + Send + Sync + 'static) -> Self {
243 Self::Wrapped(Box::new(err)).bt()
244 }
245
246 pub fn msg(err: impl std::fmt::Display) -> Self {
247 Self::Msg(err.to_string()).bt()
248 }
249
250 pub fn debug(err: impl std::fmt::Debug) -> Self {
251 Self::Msg(format!("{err:?}")).bt()
252 }
253
254 pub fn bt(self) -> Self {
255 let backtrace = std::backtrace::Backtrace::capture();
256 match backtrace.status() {
257 std::backtrace::BacktraceStatus::Disabled
258 | std::backtrace::BacktraceStatus::Unsupported => self,
259 _ => Self::WithBacktrace {
260 inner: Box::new(self),
261 backtrace: Box::new(backtrace),
262 },
263 }
264 }
265
266 pub fn with_path<P: AsRef<std::path::Path>>(self, p: P) -> Self {
267 Self::WithPath {
268 inner: Box::new(self),
269 path: p.as_ref().to_path_buf(),
270 }
271 }
272
273 pub fn context(self, c: impl std::fmt::Display + Send + Sync + 'static) -> Self {
274 Self::Context {
275 inner: Box::new(self),
276 context: Box::new(c),
277 }
278 }
279}
280
/// Returns early from the enclosing function with `Err(Error::Msg(..))`, passed through
/// `Error::bt` so a backtrace is attached when one can be captured.
///
/// Accepted forms:
/// - `bail!("text {captured}")` — a format string literal with no extra arguments;
/// - `bail!(value)` — any expression implementing `Display` (a `String`, another error, ...);
/// - `bail!("fmt {}", args...)` — a format string followed by its arguments.
#[macro_export]
macro_rules! bail {
    ($msg:literal $(,)?) => {
        return Err($crate::Error::Msg(format!($msg)).bt())
    };
    // A non-literal expression cannot be used as a `format!` string (that is a compile
    // error), so render it through its `Display` impl instead.
    ($err:expr $(,)?) => {
        return Err($crate::Error::msg($err))
    };
    ($fmt:expr, $($arg:tt)*) => {
        return Err($crate::Error::Msg(format!($fmt, $($arg)*)).bt())
    };
}
293
294pub fn zip<T, U>(r1: Result<T>, r2: Result<U>) -> Result<(T, U)> {
295 match (r1, r2) {
296 (Ok(r1), Ok(r2)) => Ok((r1, r2)),
297 (Err(e), _) => Err(e),
298 (_, Err(e)) => Err(e),
299 }
300}
301
302pub trait Context<T> {
304 fn context<C>(self, context: C) -> Result<T>
306 where
307 C: std::fmt::Display + Send + Sync + 'static;
308
309 fn with_context<C, F>(self, f: F) -> Result<T>
312 where
313 C: std::fmt::Display + Send + Sync + 'static,
314 F: FnOnce() -> C;
315}
316
317impl<T> Context<T> for Option<T> {
318 fn context<C>(self, context: C) -> Result<T>
319 where
320 C: std::fmt::Display + Send + Sync + 'static,
321 {
322 match self {
323 Some(v) => Ok(v),
324 None => Err(Error::UnwrapNone.context(context).bt()),
325 }
326 }
327
328 fn with_context<C, F>(self, f: F) -> Result<T>
329 where
330 C: std::fmt::Display + Send + Sync + 'static,
331 F: FnOnce() -> C,
332 {
333 match self {
334 Some(v) => Ok(v),
335 None => Err(Error::UnwrapNone.context(f()).bt()),
336 }
337 }
338}