constensor_core/
error.rs

1use std::{convert::Infallible, fmt::Display};
2
3#[derive(thiserror::Error, Debug)]
4pub enum Error {
5    #[error(transparent)]
6    Cuda(Box<dyn std::error::Error + Send + Sync>),
7
8    #[error("Message: {0}")]
9    Msg(String),
10
11    #[error("{inner}\n{backtrace}")]
12    WithBacktrace {
13        inner: Box<Self>,
14        backtrace: Box<std::backtrace::Backtrace>,
15    },
16
17    #[error("IO error: {0}")]
18    IoError(String),
19
20    /// Arbitrary errors wrapping.
21    #[error(transparent)]
22    Wrapped(Box<dyn std::error::Error + Send + Sync>),
23
24    /// Arbitrary errors wrapping with context.
25    #[error("{wrapped:?}\n{context:?}")]
26    WrappedContext {
27        wrapped: Box<dyn std::error::Error + Send + Sync>,
28        context: String,
29    },
30
31    #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} ostride: {out_stride:?} mnk: {mnk:?}")]
32    MatMulNonContiguous {
33        lhs_stride: Vec<usize>,
34        rhs_stride: Vec<usize>,
35        out_stride: Vec<usize>,
36        mnk: (usize, usize, usize),
37    },
38}
39
40pub type Result<T> = std::result::Result<T, Error>;
41
42impl Error {
43    /// Create a new error based on a printable error message.
44    pub fn msg<M: Display>(msg: M) -> Self {
45        Self::Msg(msg.to_string()).bt()
46    }
47
48    pub fn bt(self) -> Self {
49        let backtrace = std::backtrace::Backtrace::capture();
50        match backtrace.status() {
51            std::backtrace::BacktraceStatus::Disabled
52            | std::backtrace::BacktraceStatus::Unsupported => self,
53            _ => Self::WithBacktrace {
54                inner: Box::new(self),
55                backtrace: Box::new(backtrace),
56            },
57        }
58    }
59}
60
61impl From<std::io::Error> for Error {
62    fn from(value: std::io::Error) -> Self {
63        Error::IoError(value.to_string())
64    }
65}
66
/// Return early from the enclosing function with an `Error::Msg` error.
///
/// Inspired by `anyhow::bail!`. Accepted forms:
/// - `bail!("literal, may use {inline} captures")` — formatted with `format!`;
/// - `bail!(expr)` — any non-literal expression implementing `Display`;
/// - `bail!("fmt {}", args...)` — formatted with explicit arguments.
///
/// In every form a backtrace is attached through `Error::bt` when capture is
/// enabled.
#[macro_export]
macro_rules! bail {
    ($msg:literal $(,)?) => {
        return Err($crate::Error::Msg(format!($msg)).bt())
    };
    ($err:expr $(,)?) => {
        // `format!` only accepts a string literal as its format string, so a
        // non-literal expression is rendered through `Display` instead.
        return Err($crate::Error::msg($err))
    };
    ($fmt:expr, $($arg:tt)*) => {
        return Err($crate::Error::Msg(format!($fmt, $($arg)*)).bt())
    };
}
79
/// Sealing support for `Context`: because this module is `pub(crate)`,
/// `Sealed` cannot be named — and so `Context` cannot be implemented —
/// outside this crate.
pub(crate) mod private {
    /// Marker supertrait restricting which types may implement `Context`.
    pub trait Sealed {}

    impl<T> Sealed for Option<T> {}
    impl<T, E: std::error::Error> Sealed for std::result::Result<T, E> {}
}
86
87/// Attach more context to an error.
88///
89/// Inspired by `anyhow::Context`.
90pub trait Context<T, E>: private::Sealed {
91    /// Wrap the error value with additional context.
92    fn context<C>(self, context: C) -> std::result::Result<T, Error>
93    where
94        C: Display + Send + Sync + 'static;
95
96    /// Wrap the error value with additional context that is evaluated lazily
97    /// only once an error does occur.
98    fn with_context<C, F>(self, f: F) -> std::result::Result<T, Error>
99    where
100        C: Display + Send + Sync + 'static,
101        F: FnOnce() -> C;
102}
103
104impl<T, E> Context<T, E> for std::result::Result<T, E>
105where
106    E: std::error::Error + Send + Sync + 'static,
107{
108    fn context<C>(self, context: C) -> std::result::Result<T, Error>
109    where
110        C: Display + Send + Sync + 'static,
111    {
112        // Not using map_err to save 2 useless frames off the captured backtrace
113        // in ext_context.
114        match self {
115            Ok(ok) => Ok(ok),
116            Err(error) => Err(Error::WrappedContext {
117                wrapped: Box::new(error),
118                context: context.to_string(),
119            }),
120        }
121    }
122
123    fn with_context<C, F>(self, context: F) -> std::result::Result<T, Error>
124    where
125        C: Display + Send + Sync + 'static,
126        F: FnOnce() -> C,
127    {
128        match self {
129            Ok(ok) => Ok(ok),
130            Err(error) => Err(Error::WrappedContext {
131                wrapped: Box::new(error),
132                context: context().to_string(),
133            }),
134        }
135    }
136}
137
138impl<T> Context<T, Infallible> for Option<T> {
139    fn context<C>(self, context: C) -> std::result::Result<T, Error>
140    where
141        C: Display + Send + Sync + 'static,
142    {
143        // Not using ok_or_else to save 2 useless frames off the captured
144        // backtrace.
145        match self {
146            Some(ok) => Ok(ok),
147            None => Err(Error::msg(context)),
148        }
149    }
150
151    fn with_context<C, F>(self, context: F) -> std::result::Result<T, Error>
152    where
153        C: Display + Send + Sync + 'static,
154        F: FnOnce() -> C,
155    {
156        match self {
157            Some(ok) => Ok(ok),
158            None => Err(Error::msg(context())),
159        }
160    }
161}