1use std::{convert::Infallible, fmt::Display};
2
3#[derive(thiserror::Error, Debug)]
4pub enum Error {
5 #[error(transparent)]
6 Cuda(Box<dyn std::error::Error + Send + Sync>),
7
8 #[error("Message: {0}")]
9 Msg(String),
10
11 #[error("{inner}\n{backtrace}")]
12 WithBacktrace {
13 inner: Box<Self>,
14 backtrace: Box<std::backtrace::Backtrace>,
15 },
16
17 #[error("IO error: {0}")]
18 IoError(String),
19
20 #[error(transparent)]
22 Wrapped(Box<dyn std::error::Error + Send + Sync>),
23
24 #[error("{wrapped:?}\n{context:?}")]
26 WrappedContext {
27 wrapped: Box<dyn std::error::Error + Send + Sync>,
28 context: String,
29 },
30
31 #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} ostride: {out_stride:?} mnk: {mnk:?}")]
32 MatMulNonContiguous {
33 lhs_stride: Vec<usize>,
34 rhs_stride: Vec<usize>,
35 out_stride: Vec<usize>,
36 mnk: (usize, usize, usize),
37 },
38}
39
40pub type Result<T> = std::result::Result<T, Error>;
41
42impl Error {
43 pub fn msg<M: Display>(msg: M) -> Self {
45 Self::Msg(msg.to_string()).bt()
46 }
47
48 pub fn bt(self) -> Self {
49 let backtrace = std::backtrace::Backtrace::capture();
50 match backtrace.status() {
51 std::backtrace::BacktraceStatus::Disabled
52 | std::backtrace::BacktraceStatus::Unsupported => self,
53 _ => Self::WithBacktrace {
54 inner: Box::new(self),
55 backtrace: Box::new(backtrace),
56 },
57 }
58 }
59}
60
61impl From<std::io::Error> for Error {
62 fn from(value: std::io::Error) -> Self {
63 Error::IoError(value.to_string())
64 }
65}
66
/// Returns early from the enclosing function with `Err(Error::Msg(..))`.
///
/// The error is passed through `Error::bt`, so a backtrace is attached when capture
/// is enabled.
///
/// Accepted forms:
/// - `bail!("literal {inline}")` — a format-string literal (inline captures allowed);
/// - `bail!(value)` — any expression implementing `Display`, rendered with
///   `ToString`. The expression is deliberately not used as a format string:
///   `format!` only accepts string literals, so that would reject e.g. a `String`
///   variable;
/// - `bail!("fmt {}", args...)` — a format string followed by its arguments.
#[macro_export]
macro_rules! bail {
    ($msg:literal $(,)?) => {
        return Err($crate::Error::Msg(format!($msg)).bt())
    };
    ($err:expr $(,)?) => {
        return Err($crate::Error::Msg(::std::string::ToString::to_string(&$err)).bt())
    };
    ($fmt:expr, $($arg:tt)*) => {
        return Err($crate::Error::Msg(format!($fmt, $($arg)*)).bt())
    };
}
79
pub(crate) mod private {
    /// Supertrait of `Context`.
    ///
    /// This module is `pub(crate)`, so code outside the crate cannot name `Sealed`
    /// and therefore cannot implement `Context` for its own types.
    pub trait Sealed {}

    impl<T> Sealed for Option<T> {}

    impl<T, E: std::error::Error> Sealed for std::result::Result<T, E> {}
}
86
87pub trait Context<T, E>: private::Sealed {
91 fn context<C>(self, context: C) -> std::result::Result<T, Error>
93 where
94 C: Display + Send + Sync + 'static;
95
96 fn with_context<C, F>(self, f: F) -> std::result::Result<T, Error>
99 where
100 C: Display + Send + Sync + 'static,
101 F: FnOnce() -> C;
102}
103
104impl<T, E> Context<T, E> for std::result::Result<T, E>
105where
106 E: std::error::Error + Send + Sync + 'static,
107{
108 fn context<C>(self, context: C) -> std::result::Result<T, Error>
109 where
110 C: Display + Send + Sync + 'static,
111 {
112 match self {
115 Ok(ok) => Ok(ok),
116 Err(error) => Err(Error::WrappedContext {
117 wrapped: Box::new(error),
118 context: context.to_string(),
119 }),
120 }
121 }
122
123 fn with_context<C, F>(self, context: F) -> std::result::Result<T, Error>
124 where
125 C: Display + Send + Sync + 'static,
126 F: FnOnce() -> C,
127 {
128 match self {
129 Ok(ok) => Ok(ok),
130 Err(error) => Err(Error::WrappedContext {
131 wrapped: Box::new(error),
132 context: context().to_string(),
133 }),
134 }
135 }
136}
137
138impl<T> Context<T, Infallible> for Option<T> {
139 fn context<C>(self, context: C) -> std::result::Result<T, Error>
140 where
141 C: Display + Send + Sync + 'static,
142 {
143 match self {
146 Some(ok) => Ok(ok),
147 None => Err(Error::msg(context)),
148 }
149 }
150
151 fn with_context<C, F>(self, context: F) -> std::result::Result<T, Error>
152 where
153 C: Display + Send + Sync + 'static,
154 F: FnOnce() -> C,
155 {
156 match self {
157 Some(ok) => Ok(ok),
158 None => Err(Error::msg(context())),
159 }
160 }
161}