1use std::str::Utf8Error;
2use crate::{DType, Slice, Shape};
3
4#[derive(Debug, thiserror::Error)]
5pub enum Error {
6 #[error("{msg}, expected: {expected:?}, got: {got:?}")]
8 UnexpectedDType {
9 msg: &'static str,
10 expected: DType,
11 got: DType,
12 },
13
14 #[error("dtype mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
15 DTypeMismatchBinaryOp {
16 lhs: DType,
17 rhs: DType,
18 op: &'static str,
19 },
20
21 #[error("unsupported dtype {0:?} for op {1}")]
22 UnsupportedDTypeForOp(DType, &'static str),
23
24 #[error("Index '{index}' out of range at storage({storage_len}) in take method")]
26 IndexOutOfRangeTake {
27 storage_len: usize,
28 index: usize,
29 },
30
31 #[error("index '{index}' out of range range({max_size}) in {op}")]
32 IndexOutOfRange {
33 max_size: usize,
34 index: usize,
35 op: &'static str,
36 },
37
38 #[error("{op}: dimension index {dim} out of range for shape {shape:?}")]
39 DimOutOfRange {
40 shape: Shape,
41 dim: i32,
42 op: &'static str,
43 },
44
45 #[error("{op}: duplicate dim index {dims:?} for shape {shape:?}")]
46 DuplicateDimIndex {
47 shape: Shape,
48 dims: Vec<usize>,
49 op: &'static str,
50 },
51
52 #[error("try to repeat {repeats} for shape {shape}")]
53 RepeatRankOutOfRange {
54 repeats: Shape,
55 shape: Shape,
56 },
57
58 #[error("unexpected element size in {op}, expected: {expected}, got: {got}")]
60 ElementSizeMismatch {
61 expected: usize,
62 got: usize,
63 op: &'static str
64 },
65
66 #[error("unexpected rank, expected: {expected}, got: {got} ({shape:?})")]
67 UnexpectedNumberOfDims {
68 expected: usize,
69 got: usize,
70 shape: Shape,
71 },
72
73 #[error("{msg}, expected: {expected:?}, got: {got:?}")]
74 UnexpectedShape {
75 msg: String,
76 expected: Shape,
77 got: Shape,
78 },
79
80 #[error("requires contiguous {op}")]
81 RequiresContiguous { op: &'static str },
82
83 #[error("invalid index in {op}")]
84 InvalidIndex {
85 index: usize,
86 size: usize,
87 op: &'static str
88 },
89
90 #[error("shape mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
91 ShapeMismatchBinaryOp {
92 lhs: Shape,
93 rhs: Shape,
94 op: &'static str,
95 },
96
97 #[error("shape mismatch in cat for dim {dim}, shape for arg 1: {first_shape:?} shape for arg {n}: {nth_shape:?}")]
98 ShapeMismatchCat {
99 dim: usize,
100 first_shape: Shape,
101 n: usize,
102 nth_shape: Shape,
103 },
104
105 #[error("source Tensor shape {src:?} mismatch with condition shape {condition:?}")]
106 ShapeMismatchMaskedSelect {
107 src: Shape,
108 condition: Shape,
109 },
110
111 #[error("mask Tensor shape {mask:?} mismatch with {who} shape")]
112 ShapeMismatchSelect {
113 mask: Shape,
114 who: &'static str,
115 },
116
117 #[error("dst Tensor shape {dst:?} mismatch with src Tensor {src} shape")]
118 ShapeMismatchCopyFrom {
119 dst: Shape,
120 src: Shape,
121 },
122
123 #[error("slice invalid args {msg}: {shape:?}, dim: {dim}, slice: {slice}")]
125 SliceInvalidArgs {
126 shape: Shape,
127 dim: usize,
128 slice: Slice,
129 msg: &'static str,
130 },
131
132 #[error("narrow invalid args {msg}: {shape:?}, dim: {dim}, start: {start}, len:{len}")]
133 NarrowInvalidArgs {
134 shape: Shape,
135 dim: usize,
136 start: usize,
137 len: usize,
138 msg: &'static str,
139 },
140
141 #[error("can squeeze {dim} dim of {shape:?}(not 1)")]
142 SqueezeDimNot1 {
143 shape: Shape,
144 dim: usize,
145 },
146
147 #[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
148 BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape },
149
150 #[error("{op} expects at least one tensor")]
151 OpRequiresAtLeastOneTensor { op: &'static str },
152
153 #[error("rand error because {0}")]
154 Rand(String),
155
156 #[error("Tensor is not a scalar")]
157 NotScalar,
158
159 #[error("len mismatch with lhs {lhs} and rhs {rhs} in {op}")]
161 LenMismatchVector {
162 lhs: usize,
163 rhs: usize,
164 op: &'static str,
165 },
166
167 #[error("shape mismatch with lhs {lhs:?} and rhs {rhs:?} in {op}")]
168 ShapeMismatchMatrix {
169 lhs: (usize, usize),
170 rhs: (usize, usize),
171 op: &'static str,
172 },
173
174 #[error("index {index} of out range in {len} len vector")]
175 VectorIndexOutOfRange {
176 len: usize,
177 index: usize,
178 },
179
180 #[error("{position} index {index} of out range in {len} len matrix")]
181 MatrixIndexOutOfRange {
182 len: usize,
183 index: usize,
184 position: &'static str,
185 },
186
187 #[error("backward not support '{0}'")]
188 BackwardNotSupported(&'static str),
189
190 #[error(transparent)]
192 ParseInt(#[from] std::num::ParseIntError),
193
194 #[error(transparent)]
196 FromUtf8(#[from] std::string::FromUtf8Error),
197
198 #[error(transparent)]
200 Io(#[from] std::io::Error),
201
202 #[error(transparent)]
203 Utf8(#[from] Utf8Error),
204
205 #[error("{0}")]
207 Msg(String),
208
209 #[error("unwrap none")]
210 UnwrapNone,
211}
212
213pub type Result<T> = std::result::Result<T, Error>;
214
/// Returns early from the enclosing function with an `Error::Msg` error.
///
/// Accepted forms:
/// - `bail!("literal message")`: the literal is used as a format string, so
///   inline captures such as `"bad value {x}"` work.
/// - `bail!(value)`: any expression implementing `std::fmt::Display`. Its
///   `to_string()` becomes the message.
/// - `bail!("format {} string", args...)`: the same syntax as `format!`.
///
/// The error is raised through `?`, so the enclosing function may return any
/// `Result<_, E>` where `E: From<Error>`, not only this crate's `Result`.
#[macro_export]
macro_rules! bail {
    ($msg:literal $(,)?) => {
        return ::std::result::Result::Err($crate::Error::Msg(::std::format!($msg)))?
    };
    ($err:expr $(,)?) => {
        // `format!` only accepts string literals, so arbitrary expressions are
        // converted through `Display` instead.
        return ::std::result::Result::Err($crate::Error::Msg(
            ::std::string::ToString::to_string(&$err),
        ))?
    };
    ($fmt:expr, $($arg:tt)*) => {
        return ::std::result::Result::Err($crate::Error::Msg(::std::format!($fmt, $($arg)*)))?
    };
}