1use morok_dtype::DType;
2use smallvec::SmallVec;
3use snafu::Snafu;
4
5use crate::{BinaryOp, UnaryOp, shape::Shape};
6
/// Result alias for this module: the error type defaults to [`Error`] but can be
/// overridden through the second type parameter.
pub type Result<T, E = Error> = std::result::Result<T, E>;
8
9#[derive(Debug, Clone, PartialEq, Snafu)]
10#[snafu(visibility(pub))]
11pub enum Error {
12 #[snafu(display("dtype mismatch: cannot perform operation on {lhs:?} and {rhs:?}"))]
14 DTypeMismatch { lhs: DType, rhs: DType },
15
16 #[snafu(display("type promotion failed: no common type for {lhs:?} and {rhs:?}"))]
18 TypePromotionFailed { lhs: DType, rhs: DType },
19
20 #[snafu(display("invalid dtype for operation: operation {operation:?}; dtype {dtype:?}"))]
22 InvalidDTypeForUnaryOp { operation: UnaryOp, dtype: DType },
23
24 #[snafu(display("invalid dtype for operation: operation {operation:?}; dtypes {dtypes:?}"))]
26 InvalidDTypeForBinaryOp { operation: BinaryOp, dtypes: SmallVec<[DType; 2]> },
27
28 #[snafu(display("void dtype cannot be used in operations"))]
30 VoidTypeInOp,
31
32 #[snafu(display("index parameter must have Index dtype, got {actual:?}"))]
34 IndexTypeMismatch { actual: DType },
35
36 #[snafu(display("division by zero"))]
38 DivisionByZero,
39
40 #[snafu(display("reshape size mismatch: input size {input_size} != output size {output_size}"))]
42 ReshapeSizeMismatch { input_size: usize, output_size: usize },
43
44 #[snafu(display(
46 "shrink bounds violation: dimension {dim} has range [{begin}, {end}) but shape size is {shape_size}",
47 ))]
48 ShrinkBoundsViolation { dim: usize, begin: usize, end: usize, shape_size: usize },
49
50 #[snafu(display("bind value {value} is outside valid range [{min}, {max}]"))]
52 BindValueOutOfRange { value: i64, min: i64, max: i64 },
53
54 #[snafu(display("index out of bounds"))]
56 IndexOutOfBounds,
57
58 #[snafu(display("expand dimension mismatch: input has {input_dims} dimensions, output has {output_dims}"))]
60 ExpandDimensionMismatch { input_dims: usize, output_dims: usize },
61
62 #[snafu(display(
64 "expand invalid: dimension {dim} has size {input} but needs to expand to {output} (can only expand from 1)",
65 ))]
66 ExpandInvalidDimension { dim: usize, input: usize, output: usize },
67
68 #[snafu(display("invalid permutation {permutation:?}: expected permutation of 0..{expected_dims}"))]
70 PermuteInvalidPermutation { permutation: Vec<usize>, expected_dims: usize },
71
72 #[snafu(display(
74 "pad has negative value: dimension {dim} has padding ({begin}, {end}) but padding must be non-negative",
75 ))]
76 PadNegativeValue { dim: usize, begin: isize, end: isize },
77
78 #[snafu(display("pad dimension mismatch: padding has {padding_dims} dimensions but shape has {shape_dims}"))]
80 PadDimensionMismatch { padding_dims: usize, shape_dims: usize },
81
82 #[snafu(display("flip specification invalid: expected {expected_dims} dimensions, got {got_dims}"))]
84 FlipInvalidSpec { expected_dims: usize, got_dims: usize },
85
86 #[snafu(display("reduce axis {axis} is invalid for shape with {shape_dims} dimensions"))]
88 ReduceAxisInvalid { axis: i32, shape_dims: usize },
89
90 #[snafu(display("shape mismatch: cannot perform elementwise operation on shapes {lhs_shape:?} and {rhs_shape:?}"))]
92 ShapeMismatch { lhs_shape: Vec<usize>, rhs_shape: Vec<usize> },
93
94 #[snafu(display("Shape mismatch in {op:?}: {lhs:?} vs {rhs:?}"))]
96 BinaryShapeMismatch { op: crate::types::BinaryOp, lhs: Box<Shape>, rhs: Box<Shape> },
97
98 #[snafu(display("reshape contains negative dimension in {shape:?}"))]
100 ReshapeNegativeDimension { shape: SmallVec<[isize; 4]> },
101
102 #[snafu(display("cannot broadcast shapes {lhs:?} and {rhs:?}"))]
104 BroadcastShapeMismatch { lhs: Box<Shape>, rhs: Box<Shape> },
105
106 #[snafu(display("symbolic padding is not supported: padding dimensions must be concrete values"))]
108 SymbolicPaddingUnsupported,
109
110 #[snafu(display("symbolic shrinking is not supported: shrink ranges must be concrete values"))]
112 SymbolicShrinkingUnsupported,
113
114 #[snafu(display("symbolic shape is not supported for {operation}: shape dimensions must be concrete values"))]
116 SymbolicShapeUnsupported { operation: String },
117
118 #[snafu(display("cannot allocate buffer with symbolic size: range bound resolved to {bound:?}"))]
120 SymbolicBufferSize { bound: crate::ConstValue },
121
122 #[snafu(display(
124 "ternary operation branches have mismatched shapes: true branch {true_branch:?} vs false branch {false_branch:?}"
125 ))]
126 TernaryBranchShapeMismatch { true_branch: Box<Shape>, false_branch: Box<Shape> },
127
128 #[snafu(display(
130 "{op} must have Ptr dtype (following Tinygrad spec), got {dtype:?}. Use DefineVar for scalar variables."
131 ))]
132 BufferDefRequiresPtrDType { op: &'static str, dtype: DType },
133
134 #[snafu(display("VECTORIZE requires at least one element"))]
139 VectorizeEmpty,
140
141 #[snafu(display("VECTORIZE elements have mismatched dtypes: expected {expected:?}, got {actual:?}"))]
143 VectorizeDTypeMismatch { expected: DType, actual: DType },
144
145 #[snafu(display("GEP index {index} out of bounds for vector with {vcount} elements"))]
147 GepIndexOutOfBounds { index: usize, vcount: usize },
148
149 #[snafu(display("GEP requires vector source (vcount > 1), got {dtype:?}"))]
151 GepRequiresVector { dtype: DType },
152
153 #[snafu(display("CONTRACT dtype count {dtype_count} != axis product {axis_product}"))]
155 ContractCountMismatch { dtype_count: usize, axis_product: usize },
156
157 #[snafu(display("UNROLL src dtype count {dtype_count} != axis product {axis_product}"))]
159 UnrollCountMismatch { dtype_count: usize, axis_product: usize },
160
161 #[snafu(display("WHERE condition must be bool, got {actual:?}"))]
163 WhereConditionNotBool { actual: DType },
164
165 #[snafu(display("BROADCAST requires scalar source (vcount=1), got {dtype:?}"))]
167 BroadcastRequiresScalar { dtype: DType },
168
169 #[snafu(display(
171 "MulAcc operands must have matching dtypes (including vcount): a={a_dtype:?}, b={b_dtype:?}, c={c_dtype:?}"
172 ))]
173 MulAccDtypeMismatch { a_dtype: DType, b_dtype: DType, c_dtype: DType },
174}
175
176pub fn log_provenance(uop_id: u64, error: &Error) {
182 use crate::provenance::{PROVENANCE_TRACKER, format_chain};
183
184 PROVENANCE_TRACKER.with(|tracker| {
185 let chain = tracker.borrow().get_chain(uop_id);
186 if !chain.is_empty() {
187 tracing::error!(
188 uop.id = uop_id,
189 error = %error,
190 provenance_chain = %format_chain(&chain),
191 "uop error with provenance"
192 );
193 }
194 });
195}