// candle_core/error.rs

//! Candle-specific Error and Result
use crate::{DType, DeviceLocation, Layout, MetalError, Shape};

/// Diagnostic payload for [`Error::MatMulUnexpectedStriding`]: captures the
/// operand layouts and problem size of a matmul whose strides did not match
/// any supported kernel path. Boxed inside the error variant to keep
/// `Error` small.
#[derive(Debug, Clone)]
pub struct MatMulUnexpectedStriding {
    // Layout (shape/strides/offset) of the left-hand operand.
    pub lhs_l: Layout,
    // Layout of the right-hand operand.
    pub rhs_l: Layout,
    // Problem dimensions; the field name suggests (batch, m, n, k) —
    // NOTE(review): confirm the ordering against the call sites.
    pub bmnk: (usize, usize, usize, usize),
    // Static description of which striding expectation was violated.
    pub msg: &'static str,
}

12impl std::fmt::Debug for Error {
13    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
14        write!(f, "{self}")
15    }
16}
17
/// Main library error type.
#[derive(thiserror::Error)]
pub enum Error {
    // === DType Errors ===
    /// A tensor had a dtype other than the one the operation required.
    #[error("{msg}, expected: {expected:?}, got: {got:?}")]
    UnexpectedDType {
        msg: &'static str,
        expected: DType,
        got: DType,
    },

    /// The two operands of a binary op did not share the same dtype.
    #[error("dtype mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
    DTypeMismatchBinaryOp {
        lhs: DType,
        rhs: DType,
        op: &'static str,
    },

    /// The op has no implementation for this dtype.
    #[error("unsupported dtype {0:?} for op {1}")]
    UnsupportedDTypeForOp(DType, &'static str),

    // === Dimension Index Errors ===
    /// A dimension index was out of range for the tensor's rank.
    /// `dim` is `i32` so negative (from-the-end) indexes can be reported.
    #[error("{op}: dimension index {dim} out of range for shape {shape:?}")]
    DimOutOfRange {
        shape: Shape,
        dim: i32,
        op: &'static str,
    },

    /// The same dimension index appeared more than once in `dims`.
    #[error("{op}: duplicate dim index {dims:?} for shape {shape:?}")]
    DuplicateDimIndex {
        shape: Shape,
        dims: Vec<usize>,
        op: &'static str,
    },

    // === Shape Errors ===
    /// The tensor rank differed from what the operation expected.
    #[error("unexpected rank, expected: {expected}, got: {got} ({shape:?})")]
    UnexpectedNumberOfDims {
        expected: usize,
        got: usize,
        shape: Shape,
    },

    /// The tensor shape differed from what the operation expected.
    #[error("{msg}, expected: {expected:?}, got: {got:?}")]
    UnexpectedShape {
        msg: String,
        expected: Shape,
        got: Shape,
    },

    /// A raw buffer's element count did not match the requested shape.
    // NOTE(review): the message reads "which is compatible", yet the variant
    // reports a mismatch — it presumably means "is not compatible"; confirm
    // against the call sites before changing the user-visible string.
    #[error(
        "Shape mismatch, got buffer of size {buffer_size} which is compatible with shape {shape:?}"
    )]
    ShapeMismatch { buffer_size: usize, shape: Shape },

    /// The operands of a binary op had different shapes.
    #[error("shape mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
    ShapeMismatchBinaryOp {
        lhs: Shape,
        rhs: Shape,
        op: &'static str,
    },

    /// In a concatenation along `dim`, argument `n` disagreed with argument 1.
    #[error("shape mismatch in cat for dim {dim}, shape for arg 1: {first_shape:?} shape for arg {n}: {nth_shape:?}")]
    ShapeMismatchCat {
        dim: usize,
        first_shape: Shape,
        n: usize,
        nth_shape: Shape,
    },

    /// A split into `n_parts` equal chunks along `dim` is impossible for
    /// this shape.
    #[error("Cannot divide tensor of shape {shape:?} equally along dim {dim} into {n_parts}")]
    ShapeMismatchSplit {
        shape: Shape,
        dim: usize,
        n_parts: usize,
    },

    /// The op accepts exactly one dimension but was given several (`dims`).
    #[error("{op} can only be performed on a single dimension")]
    OnlySingleDimension { op: &'static str, dims: Vec<usize> },

    /// The op cannot operate on an empty tensor.
    #[error("empty tensor for {op}")]
    EmptyTensor { op: &'static str },

    // === Device Errors ===
    /// The operands of a binary op lived on different devices.
    #[error("device mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
    DeviceMismatchBinaryOp {
        lhs: DeviceLocation,
        rhs: DeviceLocation,
        op: &'static str,
    },

    // === Op Specific Errors ===
    /// Invalid `narrow` arguments; `msg` names the violated constraint.
    #[error("narrow invalid args {msg}: {shape:?}, dim: {dim}, start: {start}, len:{len}")]
    NarrowInvalidArgs {
        shape: Shape,
        dim: usize,
        start: usize,
        len: usize,
        msg: &'static str,
    },

    /// Invalid `conv1d` arguments for the given input/kernel shapes.
    #[error("conv1d invalid args {msg}: inp: {inp_shape:?}, k: {k_shape:?}, pad: {padding}, stride: {stride}")]
    Conv1dInvalidArgs {
        inp_shape: Shape,
        k_shape: Shape,
        padding: usize,
        stride: usize,
        msg: &'static str,
    },

    /// An index was out of bounds for a dimension of size `size`.
    #[error("{op} invalid index {index} with dim size {size}")]
    InvalidIndex {
        op: &'static str,
        index: usize,
        size: usize,
    },

    /// `src_shape` cannot be broadcast to `dst_shape`.
    #[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
    BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape },

    /// A variable could not be assigned to; `msg` gives the reason.
    #[error("cannot set variable {msg}")]
    CannotSetVar { msg: &'static str },

    // Box indirection to avoid large variant.
    /// Matmul operands had a memory layout no kernel path supports; the
    /// boxed payload carries the full diagnostic.
    #[error("{0:?}")]
    MatMulUnexpectedStriding(Box<MatMulUnexpectedStriding>),

    /// The op requires contiguous input tensors.
    #[error("{op} only supports contiguous tensors")]
    RequiresContiguous { op: &'static str },

    /// A variadic op was called with an empty list of tensors.
    #[error("{op} expects at least one tensor")]
    OpRequiresAtLeastOneTensor { op: &'static str },

    /// A variadic op was called with fewer than two tensors.
    #[error("{op} expects at least two tensors")]
    OpRequiresAtLeastTwoTensors { op: &'static str },

    /// No backward (gradient) implementation exists for this op.
    #[error("backward is not supported for {op}")]
    BackwardNotSupported { op: &'static str },

    // === Other Errors ===
    /// CUDA functionality was requested from a build without cuda support.
    #[error("the candle crate has not been built with cuda support")]
    NotCompiledWithCudaSupport,

    /// Metal functionality was requested from a build without metal support.
    #[error("the candle crate has not been built with metal support")]
    NotCompiledWithMetalSupport,

    /// Lookup of a tensor by name/key (`path`) failed.
    #[error("cannot find tensor {path}")]
    CannotFindTensor { path: String },

    // === Wrapped Errors ===
    /// Error from the CUDA backend, stored type-erased.
    #[error(transparent)]
    Cuda(Box<dyn std::error::Error + Send + Sync>),

    /// Error from the Metal backend.
    #[error("Metal error {0}")]
    Metal(#[from] MetalError),

    /// Error from the `ug` crate (not available on wasm32).
    #[cfg(not(target_arch = "wasm32"))]
    #[error(transparent)]
    Ug(#[from] ug::Error),

    /// Failed integer narrowing conversion.
    #[error(transparent)]
    TryFromIntError(#[from] core::num::TryFromIntError),

    /// Error while handling npy/npz files.
    #[error("npy/npz error {0}")]
    Npy(String),

    /// Zip file format error.
    #[error(transparent)]
    Zip(#[from] zip::result::ZipError),

    /// Integer parse error.
    #[error(transparent)]
    ParseInt(#[from] std::num::ParseIntError),

    /// Utf8 parse error.
    #[error(transparent)]
    FromUtf8(#[from] std::string::FromUtf8Error),

    /// I/O error.
    #[error(transparent)]
    Io(#[from] std::io::Error),

    /// SafeTensor error.
    #[error(transparent)]
    SafeTensor(#[from] safetensors::SafeTensorError),

    /// A safetensors dtype with no candle equivalent.
    #[error("unsupported safetensor dtype {0:?}")]
    UnsupportedSafeTensorDtype(safetensors::Dtype),

    /// Arbitrary errors wrapping.
    #[error("{0}")]
    Wrapped(Box<dyn std::fmt::Display + Send + Sync>),

    /// An inner error annotated with extra context, see [`Error::context`].
    #[error("{context}\n{inner}")]
    Context {
        inner: Box<Self>,
        context: Box<dyn std::fmt::Display + Send + Sync>,
    },

    /// Adding path information to an error.
    #[error("path: {path:?} {inner}")]
    WithPath {
        inner: Box<Self>,
        path: std::path::PathBuf,
    },

    /// An inner error annotated with a captured backtrace, see [`Error::bt`].
    #[error("{inner}\n{backtrace}")]
    WithBacktrace {
        inner: Box<Self>,
        backtrace: Box<std::backtrace::Backtrace>,
    },

    /// User generated error message, typically created via `bail!`.
    #[error("{0}")]
    Msg(String),

    /// `Option::None` unwrapped through the [`Context`] helpers.
    #[error("unwrap none")]
    UnwrapNone,
}

/// Convenience alias for `std::result::Result` specialized to this crate's
/// [`Error`] type.
pub type Result<T> = std::result::Result<T, Error>;

241impl Error {
242    pub fn wrap(err: impl std::fmt::Display + Send + Sync + 'static) -> Self {
243        Self::Wrapped(Box::new(err)).bt()
244    }
245
246    pub fn msg(err: impl std::fmt::Display) -> Self {
247        Self::Msg(err.to_string()).bt()
248    }
249
250    pub fn debug(err: impl std::fmt::Debug) -> Self {
251        Self::Msg(format!("{err:?}")).bt()
252    }
253
254    pub fn bt(self) -> Self {
255        let backtrace = std::backtrace::Backtrace::capture();
256        match backtrace.status() {
257            std::backtrace::BacktraceStatus::Disabled
258            | std::backtrace::BacktraceStatus::Unsupported => self,
259            _ => Self::WithBacktrace {
260                inner: Box::new(self),
261                backtrace: Box::new(backtrace),
262            },
263        }
264    }
265
266    pub fn with_path<P: AsRef<std::path::Path>>(self, p: P) -> Self {
267        Self::WithPath {
268            inner: Box::new(self),
269            path: p.as_ref().to_path_buf(),
270        }
271    }
272
273    pub fn context(self, c: impl std::fmt::Display + Send + Sync + 'static) -> Self {
274        Self::Context {
275            inner: Box::new(self),
276            context: Box::new(c),
277        }
278    }
279}
280
/// Return early from the enclosing function with an `Error::Msg` built from
/// `format!`-style arguments. A backtrace is attached via `Error::bt` when
/// backtrace capture is enabled.
///
/// Note: the single-expression arm forwards `$err` to `format!`, so it only
/// accepts a format-string literal, like the first arm.
#[macro_export]
macro_rules! bail {
    ($msg:literal $(,)?) => {
        // `format!` already yields a `String`, which is exactly what
        // `Error::Msg` stores — no `.into()` conversion is needed.
        return Err($crate::Error::Msg(format!($msg)).bt())
    };
    ($err:expr $(,)?) => {
        return Err($crate::Error::Msg(format!($err)).bt())
    };
    ($fmt:expr, $($arg:tt)*) => {
        return Err($crate::Error::Msg(format!($fmt, $($arg)*)).bt())
    };
}

294pub fn zip<T, U>(r1: Result<T>, r2: Result<U>) -> Result<(T, U)> {
295    match (r1, r2) {
296        (Ok(r1), Ok(r2)) => Ok((r1, r2)),
297        (Err(e), _) => Err(e),
298        (_, Err(e)) => Err(e),
299    }
300}
301
// Taken from anyhow.
/// Extension trait mirroring `anyhow::Context`: attaches a human-readable
/// context message when converting a `None` (or, in other impls, an `Err`)
/// into this crate's [`Error`].
pub trait Context<T> {
    /// Wrap the error value with additional context.
    fn context<C>(self, context: C) -> Result<T>
    where
        C: std::fmt::Display + Send + Sync + 'static;

    /// Wrap the error value with additional context that is evaluated lazily
    /// only once an error does occur.
    fn with_context<C, F>(self, f: F) -> Result<T>
    where
        C: std::fmt::Display + Send + Sync + 'static,
        F: FnOnce() -> C;
}

317impl<T> Context<T> for Option<T> {
318    fn context<C>(self, context: C) -> Result<T>
319    where
320        C: std::fmt::Display + Send + Sync + 'static,
321    {
322        match self {
323            Some(v) => Ok(v),
324            None => Err(Error::UnwrapNone.context(context).bt()),
325        }
326    }
327
328    fn with_context<C, F>(self, f: F) -> Result<T>
329    where
330        C: std::fmt::Display + Send + Sync + 'static,
331        F: FnOnce() -> C,
332    {
333        match self {
334            Some(v) => Ok(v),
335            None => Err(Error::UnwrapNone.context(f()).bt()),
336        }
337    }
338}