1use std::{convert::Infallible, fmt::Display};
3
4use crate::{DType, DeviceLocation, Layout, MetalError, Shape};
5
/// Payload of the `Error::MatMulUnexpectedStriding` variant.
///
/// That variant formats this struct with `{:?}`, so the derived `Debug` output (field order
/// included) is what users see in the error message.
#[derive(Debug, Clone)]
pub struct MatMulUnexpectedStriding {
    /// Layout of the left-hand matmul operand (presumably, from the name — confirm at call sites).
    pub lhs_l: Layout,
    /// Layout of the right-hand matmul operand (presumably, from the name — confirm at call sites).
    pub rhs_l: Layout,
    /// Presumably the `(batch, m, n, k)` matmul dimensions — TODO confirm at call sites.
    pub bmnk: (usize, usize, usize, usize),
    /// Static message describing the problem.
    pub msg: &'static str,
}
13
14impl std::fmt::Debug for Error {
15 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
16 write!(f, "{self}")
17 }
18}
19
20#[derive(thiserror::Error)]
22pub enum Error {
23 #[error("{msg}, expected: {expected:?}, got: {got:?}")]
25 UnexpectedDType {
26 msg: &'static str,
27 expected: DType,
28 got: DType,
29 },
30
31 #[error("dtype mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
32 DTypeMismatchBinaryOp {
33 lhs: DType,
34 rhs: DType,
35 op: &'static str,
36 },
37
38 #[error("unsupported dtype {0:?} for op {1}")]
39 UnsupportedDTypeForOp(DType, &'static str),
40
41 #[error("{op}: dimension index {dim} out of range for shape {shape:?}")]
43 DimOutOfRange {
44 shape: Shape,
45 dim: i32,
46 op: &'static str,
47 },
48
49 #[error("{op}: duplicate dim index {dims:?} for shape {shape:?}")]
50 DuplicateDimIndex {
51 shape: Shape,
52 dims: Vec<usize>,
53 op: &'static str,
54 },
55
56 #[error("unexpected rank, expected: {expected}, got: {got} ({shape:?})")]
58 UnexpectedNumberOfDims {
59 expected: usize,
60 got: usize,
61 shape: Shape,
62 },
63
64 #[error("{msg}, expected: {expected:?}, got: {got:?}")]
65 UnexpectedShape {
66 msg: String,
67 expected: Shape,
68 got: Shape,
69 },
70
71 #[error(
72 "Shape mismatch, got buffer of size {buffer_size} which is compatible with shape {shape:?}"
73 )]
74 ShapeMismatch { buffer_size: usize, shape: Shape },
75
76 #[error("shape mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
77 ShapeMismatchBinaryOp {
78 lhs: Shape,
79 rhs: Shape,
80 op: &'static str,
81 },
82
83 #[error("shape mismatch in cat for dim {dim}, shape for arg 1: {first_shape:?} shape for arg {n}: {nth_shape:?}")]
84 ShapeMismatchCat {
85 dim: usize,
86 first_shape: Shape,
87 n: usize,
88 nth_shape: Shape,
89 },
90
91 #[error("Cannot divide tensor of shape {shape:?} equally along dim {dim} into {n_parts}")]
92 ShapeMismatchSplit {
93 shape: Shape,
94 dim: usize,
95 n_parts: usize,
96 },
97
98 #[error("{op} can only be performed on a single dimension")]
99 OnlySingleDimension { op: &'static str, dims: Vec<usize> },
100
101 #[error("empty tensor for {op}")]
102 EmptyTensor { op: &'static str },
103
104 #[error("device mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
106 DeviceMismatchBinaryOp {
107 lhs: DeviceLocation,
108 rhs: DeviceLocation,
109 op: &'static str,
110 },
111
112 #[error("narrow invalid args {msg}: {shape:?}, dim: {dim}, start: {start}, len:{len}")]
114 NarrowInvalidArgs {
115 shape: Shape,
116 dim: usize,
117 start: usize,
118 len: usize,
119 msg: &'static str,
120 },
121
122 #[error("conv1d invalid args {msg}: inp: {inp_shape:?}, k: {k_shape:?}, pad: {padding}, stride: {stride}")]
123 Conv1dInvalidArgs {
124 inp_shape: Shape,
125 k_shape: Shape,
126 padding: usize,
127 stride: usize,
128 msg: &'static str,
129 },
130
131 #[error("{op} invalid index {index} with dim size {size}")]
132 InvalidIndex {
133 op: &'static str,
134 index: usize,
135 size: usize,
136 },
137
138 #[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
139 BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape },
140
141 #[error("cannot set variable {msg}")]
142 CannotSetVar { msg: &'static str },
143
144 #[error("{0:?}")]
146 MatMulUnexpectedStriding(Box<MatMulUnexpectedStriding>),
147
148 #[error("{op} only supports contiguous tensors")]
149 RequiresContiguous { op: &'static str },
150
151 #[error("{op} expects at least one tensor")]
152 OpRequiresAtLeastOneTensor { op: &'static str },
153
154 #[error("{op} expects at least two tensors")]
155 OpRequiresAtLeastTwoTensors { op: &'static str },
156
157 #[error("backward is not supported for {op}")]
158 BackwardNotSupported { op: &'static str },
159
160 #[error("the candle crate has not been built with cuda support")]
162 NotCompiledWithCudaSupport,
163
164 #[error("the candle crate has not been built with metal support")]
165 NotCompiledWithMetalSupport,
166
167 #[error("cannot find tensor {path}")]
168 CannotFindTensor { path: String },
169
170 #[error(transparent)]
172 Cuda(Box<dyn std::error::Error + Send + Sync>),
173
174 #[error("Metal error {0}")]
175 Metal(#[from] MetalError),
176
177 #[cfg(all(not(target_arch = "wasm32"), not(target_os = "ios"), feature = "ug"))]
178 #[error(transparent)]
179 Ug(#[from] candle_ug::Error),
180
181 #[error(transparent)]
182 TryFromIntError(#[from] core::num::TryFromIntError),
183
184 #[error("npy/npz error {0}")]
185 Npy(String),
186
187 #[error(transparent)]
189 Zip(#[from] zip::result::ZipError),
190
191 #[error(transparent)]
193 ParseInt(#[from] std::num::ParseIntError),
194
195 #[error(transparent)]
197 FromUtf8(#[from] std::string::FromUtf8Error),
198
199 #[error(transparent)]
201 Io(#[from] std::io::Error),
202
203 #[error(transparent)]
205 SafeTensor(#[from] safetensors::SafeTensorError),
206
207 #[error("unsupported safetensor dtype {0:?}")]
208 UnsupportedSafeTensorDtype(safetensors::Dtype),
209
210 #[error("{0}")]
212 Wrapped(Box<dyn std::fmt::Display + Send + Sync>),
213
214 #[error("{wrapped:?}\n{context:?}")]
216 WrappedContext {
217 wrapped: Box<dyn std::error::Error + Send + Sync>,
218 context: String,
219 },
220
221 #[error("{context}\n{inner}")]
222 Context {
223 inner: Box<Self>,
224 context: Box<dyn std::fmt::Display + Send + Sync>,
225 },
226
227 #[error("path: {path:?} {inner}")]
229 WithPath {
230 inner: Box<Self>,
231 path: std::path::PathBuf,
232 },
233
234 #[error("{inner}\n{backtrace}")]
235 WithBacktrace {
236 inner: Box<Self>,
237 backtrace: Box<std::backtrace::Backtrace>,
238 },
239
240 #[error("{0}")]
242 Msg(String),
243
244 #[error("unwrap none")]
245 UnwrapNone,
246}
247
/// Shorthand for `std::result::Result` with this module's [`Error`] as the error type.
pub type Result<T> = std::result::Result<T, Error>;
249
250impl Error {
251 pub fn wrap(err: impl std::fmt::Display + Send + Sync + 'static) -> Self {
252 Self::Wrapped(Box::new(err)).bt()
253 }
254
255 pub fn msg(err: impl std::fmt::Display) -> Self {
256 Self::Msg(err.to_string()).bt()
257 }
258
259 pub fn debug(err: impl std::fmt::Debug) -> Self {
260 Self::Msg(format!("{err:?}")).bt()
261 }
262
263 pub fn bt(self) -> Self {
264 let backtrace = std::backtrace::Backtrace::capture();
265 match backtrace.status() {
266 std::backtrace::BacktraceStatus::Disabled
267 | std::backtrace::BacktraceStatus::Unsupported => self,
268 _ => Self::WithBacktrace {
269 inner: Box::new(self),
270 backtrace: Box::new(backtrace),
271 },
272 }
273 }
274
275 pub fn with_path<P: AsRef<std::path::Path>>(self, p: P) -> Self {
276 Self::WithPath {
277 inner: Box::new(self),
278 path: p.as_ref().to_path_buf(),
279 }
280 }
281
282 pub fn context(self, c: impl std::fmt::Display + Send + Sync + 'static) -> Self {
283 Self::Context {
284 inner: Box::new(self),
285 context: Box::new(c),
286 }
287 }
288}
289
/// Returns early from the enclosing function with an `Error::Msg`.
///
/// Accepted forms:
/// - `bail!("literal with {inline} captures")`: a format string literal;
/// - `bail!(expr)`: any non-literal value implementing `Display`, e.g. a `String`;
/// - `bail!("fmt {}", args...)`: a format string followed by its arguments.
///
/// A backtrace is attached when backtrace capture is enabled.
#[macro_export]
macro_rules! bail {
    ($msg:literal $(,)?) => {
        return Err($crate::Error::Msg(format!($msg).into()).bt())
    };
    // Literals are handled by the arm above, so this arm only sees non-literal expressions.
    // Those cannot be used as a `format!` string, so render them via `Display` instead.
    // `Error::msg` also attaches the backtrace.
    ($err:expr $(,)?) => {
        return Err($crate::Error::msg($err))
    };
    ($fmt:expr, $($arg:tt)*) => {
        return Err($crate::Error::Msg(format!($fmt, $($arg)*).into()).bt())
    };
}
302
303pub fn zip<T, U>(r1: Result<T>, r2: Result<U>) -> Result<(T, U)> {
304 match (r1, r2) {
305 (Ok(r1), Ok(r2)) => Ok((r1, r2)),
306 (Err(e), _) => Err(e),
307 (_, Err(e)) => Err(e),
308 }
309}
310
pub(crate) mod private {
    /// Marker supertrait used to seal the `Context` extension trait.
    ///
    /// This module is `pub(crate)`, so code outside the crate cannot name `Sealed`. As a
    /// result, only the types listed here can implement the sealed trait.
    pub trait Sealed {}

    impl<T> Sealed for Option<T> {}

    impl<T, E: std::error::Error> Sealed for std::result::Result<T, E> {}
}
317
318pub trait Context<T, E>: private::Sealed {
322 fn context<C>(self, context: C) -> std::result::Result<T, Error>
324 where
325 C: Display + Send + Sync + 'static;
326
327 fn with_context<C, F>(self, f: F) -> std::result::Result<T, Error>
330 where
331 C: Display + Send + Sync + 'static,
332 F: FnOnce() -> C;
333}
334
335impl<T, E> Context<T, E> for std::result::Result<T, E>
336where
337 E: std::error::Error + Send + Sync + 'static,
338{
339 fn context<C>(self, context: C) -> std::result::Result<T, Error>
340 where
341 C: Display + Send + Sync + 'static,
342 {
343 match self {
346 Ok(ok) => Ok(ok),
347 Err(error) => Err(Error::WrappedContext {
348 wrapped: Box::new(error),
349 context: context.to_string(),
350 }
351 .bt()),
352 }
353 }
354
355 fn with_context<C, F>(self, context: F) -> std::result::Result<T, Error>
356 where
357 C: Display + Send + Sync + 'static,
358 F: FnOnce() -> C,
359 {
360 match self {
361 Ok(ok) => Ok(ok),
362 Err(error) => Err(Error::WrappedContext {
363 wrapped: Box::new(error),
364 context: context().to_string(),
365 }
366 .bt()),
367 }
368 }
369}
370
371impl<T> Context<T, Infallible> for Option<T> {
372 fn context<C>(self, context: C) -> std::result::Result<T, Error>
373 where
374 C: Display + Send + Sync + 'static,
375 {
376 match self {
379 Some(ok) => Ok(ok),
380 None => Err(Error::msg(context).bt()),
381 }
382 }
383
384 fn with_context<C, F>(self, context: F) -> std::result::Result<T, Error>
385 where
386 C: Display + Send + Sync + 'static,
387 F: FnOnce() -> C,
388 {
389 match self {
390 Some(v) => Ok(v),
391 None => Err(Error::UnwrapNone.context(context()).bt()),
392 }
393 }
394}