Skip to main content

lumen_core/
error.rs

1use std::str::Utf8Error;
2use crate::{DType, Slice, Shape};
3
4#[derive(Debug, thiserror::Error)]
5pub enum Error {
6    // === DType Errors ===
7    #[error("{msg}, expected: {expected:?}, got: {got:?}")]
8    UnexpectedDType {
9        msg: &'static str,
10        expected: DType,
11        got: DType,
12    },
13
14    #[error("dtype mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
15    DTypeMismatchBinaryOp {
16        lhs: DType,
17        rhs: DType,
18        op: &'static str,
19    },
20
21    #[error("unsupported dtype {0:?} for op {1}")]
22    UnsupportedDTypeForOp(DType, &'static str),
23
24    // === Dimension Index Errors ===
25    #[error("Index '{index}' out of range at storage({storage_len}) in take method")]
26    IndexOutOfRangeTake {
27        storage_len: usize,
28        index: usize,
29    },
30
31    #[error("index '{index}' out of range range({max_size}) in {op}")]
32    IndexOutOfRange {
33        max_size: usize,
34        index: usize,
35        op: &'static str,
36    },
37
38    #[error("{op}: dimension index {dim} out of range for shape {shape:?}")]
39    DimOutOfRange {
40        shape: Shape,
41        dim: i32,
42        op: &'static str,
43    },
44
45    #[error("{op}: duplicate dim index {dims:?} for shape {shape:?}")]
46    DuplicateDimIndex {
47        shape: Shape,
48        dims: Vec<usize>,
49        op: &'static str,
50    },
51
52    #[error("try to repeat {repeats} for shape {shape}")]
53    RepeatRankOutOfRange {
54        repeats: Shape,
55        shape: Shape,
56    },
57
58    // === Shape Errors ===
59    #[error("unexpected element size in {op}, expected: {expected}, got: {got}")]
60    ElementSizeMismatch {
61        expected: usize,
62        got: usize,
63        op: &'static str
64    },
65
66    #[error("unexpected rank, expected: {expected}, got: {got} ({shape:?})")]
67    UnexpectedNumberOfDims {
68        expected: usize,
69        got: usize,
70        shape: Shape,
71    },
72
73    #[error("{msg}, expected: {expected:?}, got: {got:?}")]
74    UnexpectedShape {
75        msg: String,
76        expected: Shape,
77        got: Shape,
78    },
79
80    #[error("requires contiguous {op}")]
81    RequiresContiguous { op: &'static str },
82
83    #[error("invalid index in {op}")]
84    InvalidIndex {
85        index: usize,
86        size: usize,
87        op: &'static str 
88    },
89
90    #[error("shape mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
91    ShapeMismatchBinaryOp {
92        lhs: Shape,
93        rhs: Shape,
94        op: &'static str,
95    },
96
97    #[error("shape mismatch in cat for dim {dim}, shape for arg 1: {first_shape:?} shape for arg {n}: {nth_shape:?}")]
98    ShapeMismatchCat {
99        dim: usize,
100        first_shape: Shape,
101        n: usize,
102        nth_shape: Shape,
103    },
104
105    #[error("source Tensor shape {src:?} mismatch with condition shape {condition:?}")]
106    ShapeMismatchMaskedSelect {
107        src: Shape,
108        condition: Shape, 
109    },
110
111    #[error("mask Tensor shape {mask:?} mismatch with {who} shape")]
112    ShapeMismatchSelect {
113        mask: Shape,
114        who: &'static str,
115    },
116
117    #[error("dst Tensor shape {dst:?} mismatch with src Tensor {src} shape")]
118    ShapeMismatchCopyFrom {
119        dst: Shape,
120        src: Shape,
121    },
122
123    // === Op Specific Errors ===
124    #[error("slice invalid args {msg}: {shape:?}, dim: {dim}, slice: {slice}")]
125    SliceInvalidArgs {
126        shape: Shape,
127        dim: usize,
128        slice: Slice,
129        msg: &'static str,
130    },
131
132    #[error("narrow invalid args {msg}: {shape:?}, dim: {dim}, start: {start}, len:{len}")]
133    NarrowInvalidArgs {
134        shape: Shape,
135        dim: usize,
136        start: usize,
137        len: usize,
138        msg: &'static str,
139    },
140
141    #[error("can squeeze {dim} dim of {shape:?}(not 1)")]
142    SqueezeDimNot1 {
143        shape: Shape,
144        dim: usize,
145    },
146
147    #[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
148    BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape },
149
150    #[error("{op} expects at least one tensor")]
151    OpRequiresAtLeastOneTensor { op: &'static str },
152
153    #[error("rand error because {0}")]
154    Rand(String),
155
156    #[error("Tensor is not a scalar")]
157    NotScalar,
158
159    // === View ===
160    #[error("len mismatch with lhs {lhs} and rhs {rhs} in {op}")]
161    LenMismatchVector {
162        lhs: usize,
163        rhs: usize,
164        op: &'static str,
165    },
166
167    #[error("shape mismatch with lhs {lhs:?} and rhs {rhs:?} in {op}")]
168    ShapeMismatchMatrix {
169        lhs: (usize, usize),
170        rhs: (usize, usize),
171        op: &'static str,
172    },
173
174    #[error("boolean index should like vector, but got {0}")]
175    BooleanIndexShouldLikeVector(Shape),
176
177    #[error("index {index} of out range in {len} len vector")]
178    VectorIndexOutOfRange {
179        len: usize,
180        index: usize,
181    },
182
183    #[error("{position} index {index} of out range in {len} len matrix")]
184    MatrixIndexOutOfRange {
185        len: usize,
186        index: usize,
187        position: &'static str,
188    },
189
190    #[error("backward not support '{0}'")]
191    BackwardNotSupported(&'static str),
192
193    /// Integer parse error.
194    #[error(transparent)]
195    ParseInt(#[from] std::num::ParseIntError),
196
197    /// Utf8 parse error.
198    #[error(transparent)]
199    FromUtf8(#[from] std::string::FromUtf8Error),
200
201    /// I/O error.
202    #[error(transparent)]
203    Io(#[from] std::io::Error),
204
205    #[error(transparent)]
206    Utf8(#[from] Utf8Error),
207
208    /// Storage error 
209    #[error("visit a meta tensor!")]
210    MetaTensor,
211
212    /// User generated error message
213    #[error("{0}")]
214    Msg(String),
215
216    #[error("unwrap none")]
217    UnwrapNone,
218}
219
220pub type Result<T> = std::result::Result<T, Error>;
221
/// Returns early from the current function with an `Error::Msg`.
///
/// Accepted forms:
/// - `bail!("message {x}")`: a string literal, formatted with `format!`
///   (inline captures work);
/// - `bail!("fmt {} {}", a, b)`: a format string with arguments;
/// - `bail!(expr)`: any expression implementing `Display`; its
///   `to_string()` becomes the message.
///
/// The expansion is `return Err(..)?`. The `?` converts the error through
/// `From`, so the macro also works in functions whose error type implements
/// `From<Error>`. The outer `return` ties the otherwise-unconstrained `Ok`
/// type of that `Err` to the function's return type, so inference succeeds.
#[macro_export]
macro_rules! bail {
    ($msg:literal $(,)?) => {
        return Err($crate::Error::Msg(::std::format!($msg)))?
    };
    ($err:expr $(,)?) => {
        // `format!` only accepts string literals, so arbitrary expressions
        // are rendered via `Display` instead.
        return Err($crate::Error::Msg(::std::string::ToString::to_string(&$err)))?
    };
    ($fmt:expr, $($arg:tt)*) => {
        return Err($crate::Error::Msg(::std::format!($fmt, $($arg)*)))?
    };
}