candle_core/
error.rs

1//! Candle-specific Error and Result
2use std::{convert::Infallible, fmt::Display};
3
4use crate::{DType, DeviceLocation, Layout, MetalError, Shape};
5
/// Diagnostic payload for [`Error::MatMulUnexpectedStriding`]: captures the
/// operand layouts and problem sizes when a matmul receives strides it cannot
/// handle. Boxed in the error enum to keep the `Error` type small.
#[derive(Debug, Clone)]
pub struct MatMulUnexpectedStriding {
    /// Layout of the left-hand-side operand.
    pub lhs_l: Layout,
    /// Layout of the right-hand-side operand.
    pub rhs_l: Layout,
    /// The four matmul problem sizes — presumably (batch, m, n, k) going by
    /// the field name; confirm against the matmul call sites.
    pub bmnk: (usize, usize, usize, usize),
    /// Static description of why the striding was rejected.
    pub msg: &'static str,
}
13
14impl std::fmt::Debug for Error {
15    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
16        write!(f, "{self}")
17    }
18}
19
/// Main library error type.
#[derive(thiserror::Error)]
pub enum Error {
    // === DType Errors ===
    /// An operation received a tensor whose dtype differs from the expected one.
    #[error("{msg}, expected: {expected:?}, got: {got:?}")]
    UnexpectedDType {
        msg: &'static str,
        expected: DType,
        got: DType,
    },

    /// The two operands of a binary op do not share the same dtype.
    #[error("dtype mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
    DTypeMismatchBinaryOp {
        lhs: DType,
        rhs: DType,
        op: &'static str,
    },

    /// The op has no implementation for this dtype.
    #[error("unsupported dtype {0:?} for op {1}")]
    UnsupportedDTypeForOp(DType, &'static str),

    // === Dimension Index Errors ===
    /// A dimension index is outside the tensor's rank (`i32` so negative,
    /// from-the-end indices can be reported as given by the caller).
    #[error("{op}: dimension index {dim} out of range for shape {shape:?}")]
    DimOutOfRange {
        shape: Shape,
        dim: i32,
        op: &'static str,
    },

    /// The same dimension index was specified more than once.
    #[error("{op}: duplicate dim index {dims:?} for shape {shape:?}")]
    DuplicateDimIndex {
        shape: Shape,
        dims: Vec<usize>,
        op: &'static str,
    },

    // === Shape Errors ===
    /// The tensor rank differs from what the op requires.
    #[error("unexpected rank, expected: {expected}, got: {got} ({shape:?})")]
    UnexpectedNumberOfDims {
        expected: usize,
        got: usize,
        shape: Shape,
    },

    /// A tensor shape differs from the expected one.
    #[error("{msg}, expected: {expected:?}, got: {got:?}")]
    UnexpectedShape {
        msg: String,
        expected: Shape,
        got: Shape,
    },

    /// A raw buffer's element count does not match the requested shape.
    // NOTE(review): the message reads "is compatible" but the variant name
    // suggests it is raised on a mismatch — possibly a missing "not"; confirm
    // at call sites before changing the user-visible string.
    #[error(
        "Shape mismatch, got buffer of size {buffer_size} which is compatible with shape {shape:?}"
    )]
    ShapeMismatch { buffer_size: usize, shape: Shape },

    /// The two operands of a binary op have incompatible shapes.
    #[error("shape mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
    ShapeMismatchBinaryOp {
        lhs: Shape,
        rhs: Shape,
        op: &'static str,
    },

    /// An argument to `cat` disagrees with the first argument's shape.
    #[error("shape mismatch in cat for dim {dim}, shape for arg 1: {first_shape:?} shape for arg {n}: {nth_shape:?}")]
    ShapeMismatchCat {
        dim: usize,
        first_shape: Shape,
        n: usize,
        nth_shape: Shape,
    },

    /// A split/chunk op cannot divide dimension `dim` evenly into `n_parts`.
    #[error("Cannot divide tensor of shape {shape:?} equally along dim {dim} into {n_parts}")]
    ShapeMismatchSplit {
        shape: Shape,
        dim: usize,
        n_parts: usize,
    },

    /// Multiple dims were passed to an op that accepts exactly one.
    // NOTE(review): `dims` is carried for programmatic inspection but is not
    // interpolated into the message.
    #[error("{op} can only be performed on a single dimension")]
    OnlySingleDimension { op: &'static str, dims: Vec<usize> },

    /// The op does not accept empty tensors.
    #[error("empty tensor for {op}")]
    EmptyTensor { op: &'static str },

    // === Device Errors ===
    /// The two operands of a binary op live on different devices.
    #[error("device mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
    DeviceMismatchBinaryOp {
        lhs: DeviceLocation,
        rhs: DeviceLocation,
        op: &'static str,
    },

    // === Op Specific Errors ===
    /// Invalid `narrow` arguments (e.g. start/len out of range for `dim`).
    #[error("narrow invalid args {msg}: {shape:?}, dim: {dim}, start: {start}, len:{len}")]
    NarrowInvalidArgs {
        shape: Shape,
        dim: usize,
        start: usize,
        len: usize,
        msg: &'static str,
    },

    /// Invalid `conv1d` arguments (input/kernel shapes vs padding/stride).
    #[error("conv1d invalid args {msg}: inp: {inp_shape:?}, k: {k_shape:?}, pad: {padding}, stride: {stride}")]
    Conv1dInvalidArgs {
        inp_shape: Shape,
        k_shape: Shape,
        padding: usize,
        stride: usize,
        msg: &'static str,
    },

    /// An indexing op received an index out of range for the indexed dim.
    #[error("{op} invalid index {index} with dim size {size}")]
    InvalidIndex {
        op: &'static str,
        index: usize,
        size: usize,
    },

    /// Broadcasting rules cannot map `src_shape` onto `dst_shape`.
    #[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
    BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape },

    /// A variable could not be updated.
    #[error("cannot set variable {msg}")]
    CannotSetVar { msg: &'static str },

    // Box indirection to avoid large variant.
    /// A matmul received operand layouts whose strides it cannot handle.
    #[error("{0:?}")]
    MatMulUnexpectedStriding(Box<MatMulUnexpectedStriding>),

    /// The op only works on contiguous tensors.
    #[error("{op} only supports contiguous tensors")]
    RequiresContiguous { op: &'static str },

    /// A variadic op was called with zero tensors.
    #[error("{op} expects at least one tensor")]
    OpRequiresAtLeastOneTensor { op: &'static str },

    /// A variadic op was called with fewer than two tensors.
    #[error("{op} expects at least two tensors")]
    OpRequiresAtLeastTwoTensors { op: &'static str },

    /// Gradient computation is not implemented for this op.
    #[error("backward is not supported for {op}")]
    BackwardNotSupported { op: &'static str },

    // === Other Errors ===
    /// A CUDA device/op was requested but the crate lacks cuda support.
    #[error("the candle crate has not been built with cuda support")]
    NotCompiledWithCudaSupport,

    /// A Metal device/op was requested but the crate lacks metal support.
    #[error("the candle crate has not been built with metal support")]
    NotCompiledWithMetalSupport,

    /// A named tensor was not found (e.g. when loading weights by name).
    #[error("cannot find tensor {path}")]
    CannotFindTensor { path: String },

    // === Wrapped Errors ===
    /// Error bubbled up from the CUDA backend, type-erased.
    #[error(transparent)]
    Cuda(Box<dyn std::error::Error + Send + Sync>),

    /// Error bubbled up from the Metal backend.
    #[error("Metal error {0}")]
    Metal(#[from] MetalError),

    /// Error from the `ug` kernel dependency (gated off on wasm32/iOS).
    #[cfg(all(not(target_arch = "wasm32"), not(target_os = "ios"), feature = "ug"))]
    #[error(transparent)]
    Ug(#[from] candle_ug::Error),

    /// Integer narrowing/conversion error.
    #[error(transparent)]
    TryFromIntError(#[from] core::num::TryFromIntError),

    /// Error while handling npy/npz files.
    #[error("npy/npz error {0}")]
    Npy(String),

    /// Zip file format error.
    #[error(transparent)]
    Zip(#[from] zip::result::ZipError),

    /// Integer parse error.
    #[error(transparent)]
    ParseInt(#[from] std::num::ParseIntError),

    /// Utf8 parse error.
    #[error(transparent)]
    FromUtf8(#[from] std::string::FromUtf8Error),

    /// I/O error.
    #[error(transparent)]
    Io(#[from] std::io::Error),

    /// SafeTensor error.
    #[error(transparent)]
    SafeTensor(#[from] safetensors::SafeTensorError),

    /// A safetensors dtype with no candle equivalent.
    #[error("unsupported safetensor dtype {0:?}")]
    UnsupportedSafeTensorDtype(safetensors::Dtype),

    /// Arbitrary errors wrapping.
    #[error("{0}")]
    Wrapped(Box<dyn std::fmt::Display + Send + Sync>),

    /// Arbitrary errors wrapping with context.
    // NOTE(review): both fields render with Debug formatting (`{:?}`), so the
    // context string appears quoted/escaped in the message — confirm intended.
    #[error("{wrapped:?}\n{context:?}")]
    WrappedContext {
        wrapped: Box<dyn std::error::Error + Send + Sync>,
        context: String,
    },

    /// An error augmented with a context line; built by [`Error::context`].
    #[error("{context}\n{inner}")]
    Context {
        inner: Box<Self>,
        context: Box<dyn std::fmt::Display + Send + Sync>,
    },

    /// Adding path information to an error.
    #[error("path: {path:?} {inner}")]
    WithPath {
        inner: Box<Self>,
        path: std::path::PathBuf,
    },

    /// An error augmented with a captured backtrace; built by [`Error::bt`].
    #[error("{inner}\n{backtrace}")]
    WithBacktrace {
        inner: Box<Self>,
        backtrace: Box<std::backtrace::Backtrace>,
    },

    /// User generated error message, typically created via `bail!`.
    #[error("{0}")]
    Msg(String),

    /// `None` was unwrapped through the [`Context`] helpers.
    #[error("unwrap none")]
    UnwrapNone,
}
247
248pub type Result<T> = std::result::Result<T, Error>;
249
250impl Error {
251    pub fn wrap(err: impl std::fmt::Display + Send + Sync + 'static) -> Self {
252        Self::Wrapped(Box::new(err)).bt()
253    }
254
255    pub fn msg(err: impl std::fmt::Display) -> Self {
256        Self::Msg(err.to_string()).bt()
257    }
258
259    pub fn debug(err: impl std::fmt::Debug) -> Self {
260        Self::Msg(format!("{err:?}")).bt()
261    }
262
263    pub fn bt(self) -> Self {
264        let backtrace = std::backtrace::Backtrace::capture();
265        match backtrace.status() {
266            std::backtrace::BacktraceStatus::Disabled
267            | std::backtrace::BacktraceStatus::Unsupported => self,
268            _ => Self::WithBacktrace {
269                inner: Box::new(self),
270                backtrace: Box::new(backtrace),
271            },
272        }
273    }
274
275    pub fn with_path<P: AsRef<std::path::Path>>(self, p: P) -> Self {
276        Self::WithPath {
277            inner: Box::new(self),
278            path: p.as_ref().to_path_buf(),
279        }
280    }
281
282    pub fn context(self, c: impl std::fmt::Display + Send + Sync + 'static) -> Self {
283        Self::Context {
284            inner: Box::new(self),
285            context: Box::new(c),
286        }
287    }
288}
289
/// Return early from the enclosing function with an `Err(Error::Msg(..))`
/// built `format!`-style, with a backtrace attached when capture is enabled.
#[macro_export]
macro_rules! bail {
    // Single format literal, e.g. `bail!("boom")`.
    ($msg:literal $(,)?) => {
        return Err($crate::Error::Msg(format!($msg).into()).bt())
    };
    // Single expression used as the format string.
    ($err:expr $(,)?) => {
        return Err($crate::Error::Msg(format!($err).into()).bt())
    };
    // Format string plus arguments, e.g. `bail!("bad dim {}", d)`.
    ($fmt:expr, $($arg:tt)*) => {
        return Err($crate::Error::Msg(format!($fmt, $($arg)*).into()).bt())
    };
}
302
303pub fn zip<T, U>(r1: Result<T>, r2: Result<U>) -> Result<(T, U)> {
304    match (r1, r2) {
305        (Ok(r1), Ok(r2)) => Ok((r1, r2)),
306        (Err(e), _) => Err(e),
307        (_, Err(e)) => Err(e),
308    }
309}
310
// Sealing module: `Context` requires `private::Sealed`, and since downstream
// crates cannot name this module, only the impls below can exist.
pub(crate) mod private {
    pub trait Sealed {}

    impl<T, E> Sealed for std::result::Result<T, E> where E: std::error::Error {}
    impl<T> Sealed for Option<T> {}
}
317
/// Attach more context to an error.
///
/// Inspired by [`anyhow::Context`]. Sealed: implemented only for
/// `std::result::Result` and `Option` below.
pub trait Context<T, E>: private::Sealed {
    /// Wrap the error value with additional context.
    ///
    /// The context is converted eagerly; prefer [`Context::with_context`]
    /// when building it is costly.
    fn context<C>(self, context: C) -> std::result::Result<T, Error>
    where
        C: Display + Send + Sync + 'static;

    /// Wrap the error value with additional context that is evaluated lazily
    /// only once an error does occur.
    fn with_context<C, F>(self, f: F) -> std::result::Result<T, Error>
    where
        C: Display + Send + Sync + 'static,
        F: FnOnce() -> C;
}
334
335impl<T, E> Context<T, E> for std::result::Result<T, E>
336where
337    E: std::error::Error + Send + Sync + 'static,
338{
339    fn context<C>(self, context: C) -> std::result::Result<T, Error>
340    where
341        C: Display + Send + Sync + 'static,
342    {
343        // Not using map_err to save 2 useless frames off the captured backtrace
344        // in ext_context.
345        match self {
346            Ok(ok) => Ok(ok),
347            Err(error) => Err(Error::WrappedContext {
348                wrapped: Box::new(error),
349                context: context.to_string(),
350            }
351            .bt()),
352        }
353    }
354
355    fn with_context<C, F>(self, context: F) -> std::result::Result<T, Error>
356    where
357        C: Display + Send + Sync + 'static,
358        F: FnOnce() -> C,
359    {
360        match self {
361            Ok(ok) => Ok(ok),
362            Err(error) => Err(Error::WrappedContext {
363                wrapped: Box::new(error),
364                context: context().to_string(),
365            }
366            .bt()),
367        }
368    }
369}
370
371impl<T> Context<T, Infallible> for Option<T> {
372    fn context<C>(self, context: C) -> std::result::Result<T, Error>
373    where
374        C: Display + Send + Sync + 'static,
375    {
376        // Not using ok_or_else to save 2 useless frames off the captured
377        // backtrace.
378        match self {
379            Some(ok) => Ok(ok),
380            None => Err(Error::msg(context).bt()),
381        }
382    }
383
384    fn with_context<C, F>(self, context: F) -> std::result::Result<T, Error>
385    where
386        C: Display + Send + Sync + 'static,
387        F: FnOnce() -> C,
388    {
389        match self {
390            Some(v) => Ok(v),
391            None => Err(Error::UnwrapNone.context(context()).bt()),
392        }
393    }
394}