// morok_ir/error.rs

1use morok_dtype::DType;
2use smallvec::SmallVec;
3use snafu::Snafu;
4
5use crate::{BinaryOp, UnaryOp, shape::Shape};
6
/// Result alias whose error type defaults to this module's [`Error`].
///
/// The error parameter can still be overridden, e.g. `Result<T, OtherError>`.
pub type Result<T, E = Error> = std::result::Result<T, E>;
8
/// Errors describing invalid dtypes, shapes, and operation arguments.
///
/// Every variant carries an explicit `#[snafu(display(...))]` message, so the
/// `///` comments below are documentation only and are not used as `Display`
/// text. `#[snafu(visibility(pub))]` makes the Snafu-generated context
/// selectors public.
#[derive(Debug, Clone, PartialEq, Snafu)]
#[snafu(visibility(pub))]
pub enum Error {
    /// DType mismatch in a binary operation: `lhs` and `rhs` cannot be combined.
    #[snafu(display("dtype mismatch: cannot perform operation on {lhs:?} and {rhs:?}"))]
    DTypeMismatch { lhs: DType, rhs: DType },

    /// Type promotion failed: no common type exists for `lhs` and `rhs`.
    #[snafu(display("type promotion failed: no common type for {lhs:?} and {rhs:?}"))]
    TypePromotionFailed { lhs: DType, rhs: DType },

    /// Invalid dtype for a unary operation (e.g., bitwise on float).
    #[snafu(display("invalid dtype for operation: operation {operation:?}; dtype {dtype:?}"))]
    InvalidDTypeForUnaryOp { operation: UnaryOp, dtype: DType },

    /// Invalid dtype(s) for a binary operation (e.g., bitwise on float).
    ///
    /// `dtypes` holds the offending operand dtypes.
    #[snafu(display("invalid dtype for operation: operation {operation:?}; dtypes {dtypes:?}"))]
    InvalidDTypeForBinaryOp { operation: BinaryOp, dtypes: SmallVec<[DType; 2]> },

    /// Void dtype cannot be used in operations.
    #[snafu(display("void dtype cannot be used in operations"))]
    VoidTypeInOp,

    /// Index parameter must have Index dtype; `actual` is the dtype that was supplied.
    #[snafu(display("index parameter must have Index dtype, got {actual:?}"))]
    IndexTypeMismatch { actual: DType },

    /// Division by zero.
    #[snafu(display("division by zero"))]
    DivisionByZero,

    /// Reshape size mismatch: input and output element counts differ.
    #[snafu(display("reshape size mismatch: input size {input_size} != output size {output_size}"))]
    ReshapeSizeMismatch { input_size: usize, output_size: usize },

    /// Shrink bounds violation: the half-open range `[begin, end)` requested for
    /// dimension `dim` does not fit a dimension of size `shape_size`.
    #[snafu(display(
        "shrink bounds violation: dimension {dim} has range [{begin}, {end}) but shape size is {shape_size}",
    ))]
    ShrinkBoundsViolation { dim: usize, begin: usize, end: usize, shape_size: usize },

    /// Bind value out of range: `value` lies outside the inclusive range `[min, max]`.
    #[snafu(display("bind value {value} is outside valid range [{min}, {max}]"))]
    BindValueOutOfRange { value: i64, min: i64, max: i64 },

    /// Index out of bounds.
    #[snafu(display("index out of bounds"))]
    IndexOutOfBounds,

    /// Expand dimension count mismatch between input and output shapes.
    #[snafu(display("expand dimension mismatch: input has {input_dims} dimensions, output has {output_dims}"))]
    ExpandDimensionMismatch { input_dims: usize, output_dims: usize },

    /// Expand invalid dimension (can only expand dimensions of size 1).
    ///
    /// `input` is the current size of dimension `dim`, `output` the requested size.
    #[snafu(display(
        "expand invalid: dimension {dim} has size {input} but needs to expand to {output} (can only expand from 1)",
    ))]
    ExpandInvalidDimension { dim: usize, input: usize, output: usize },

    /// Permute has invalid permutation: `permutation` is not a permutation of
    /// `0..expected_dims`.
    #[snafu(display("invalid permutation {permutation:?}: expected permutation of 0..{expected_dims}"))]
    PermuteInvalidPermutation { permutation: Vec<usize>, expected_dims: usize },

    /// Pad has a negative padding value in dimension `dim`.
    #[snafu(display(
        "pad has negative value: dimension {dim} has padding ({begin}, {end}) but padding must be non-negative",
    ))]
    PadNegativeValue { dim: usize, begin: isize, end: isize },

    /// Pad dimension count mismatch between the padding spec and the shape.
    #[snafu(display("pad dimension mismatch: padding has {padding_dims} dimensions but shape has {shape_dims}"))]
    PadDimensionMismatch { padding_dims: usize, shape_dims: usize },

    /// Flip specification invalid: it has the wrong number of dimensions.
    #[snafu(display("flip specification invalid: expected {expected_dims} dimensions, got {got_dims}"))]
    FlipInvalidSpec { expected_dims: usize, got_dims: usize },

    /// Reduce axis is invalid for a shape with `shape_dims` dimensions.
    #[snafu(display("reduce axis {axis} is invalid for shape with {shape_dims} dimensions"))]
    ReduceAxisInvalid { axis: i32, shape_dims: usize },

    /// Shape mismatch in elementwise operation (shapes given as plain dimension lists).
    #[snafu(display("shape mismatch: cannot perform elementwise operation on shapes {lhs_shape:?} and {rhs_shape:?}"))]
    ShapeMismatch { lhs_shape: Vec<usize>, rhs_shape: Vec<usize> },

    /// Shape mismatch in binary operation `op` (shapes given as boxed [`Shape`] values).
    #[snafu(display("Shape mismatch in {op:?}: {lhs:?} vs {rhs:?}"))]
    BinaryShapeMismatch { op: crate::types::BinaryOp, lhs: Box<Shape>, rhs: Box<Shape> },

    /// Reshape contains negative dimension; `shape` is the requested target shape.
    #[snafu(display("reshape contains negative dimension in {shape:?}"))]
    ReshapeNegativeDimension { shape: SmallVec<[isize; 4]> },

    /// Broadcasting shape mismatch: `lhs` and `rhs` cannot be broadcast together.
    #[snafu(display("cannot broadcast shapes {lhs:?} and {rhs:?}"))]
    BroadcastShapeMismatch { lhs: Box<Shape>, rhs: Box<Shape> },

    /// Symbolic padding unsupported: padding dimensions must be concrete values.
    #[snafu(display("symbolic padding is not supported: padding dimensions must be concrete values"))]
    SymbolicPaddingUnsupported,

    /// Symbolic shrinking unsupported: shrink ranges must be concrete values.
    #[snafu(display("symbolic shrinking is not supported: shrink ranges must be concrete values"))]
    SymbolicShrinkingUnsupported,

    /// Symbolic shape unsupported for the operation named by `operation`.
    #[snafu(display("symbolic shape is not supported for {operation}: shape dimensions must be concrete values"))]
    SymbolicShapeUnsupported { operation: String },

    /// Symbolic buffer size unsupported; `bound` is the value the range bound resolved to.
    #[snafu(display("cannot allocate buffer with symbolic size: range bound resolved to {bound:?}"))]
    SymbolicBufferSize { bound: crate::ConstValue },

    /// Ternary branch shape mismatch between the true and false branches.
    #[snafu(display(
        "ternary operation branches have mismatched shapes: true branch {true_branch:?} vs false branch {false_branch:?}"
    ))]
    TernaryBranchShapeMismatch { true_branch: Box<Shape>, false_branch: Box<Shape> },

    /// A buffer-definition op (e.g. DefineLocal) must have Ptr dtype.
    ///
    /// `op` is the name of the offending op as it appears in the message;
    /// `dtype` is the non-Ptr dtype that was supplied.
    #[snafu(display(
        "{op} must have Ptr dtype (following Tinygrad spec), got {dtype:?}. Use DefineVar for scalar variables."
    ))]
    BufferDefRequiresPtrDType { op: &'static str, dtype: DType },

    // =========================================================================
    // UOp Builder Guards (user-facing API for kernel implementation)
    // =========================================================================
    /// VECTORIZE requires at least one element.
    #[snafu(display("VECTORIZE requires at least one element"))]
    VectorizeEmpty,

    /// VECTORIZE elements have mismatched dtypes.
    #[snafu(display("VECTORIZE elements have mismatched dtypes: expected {expected:?}, got {actual:?}"))]
    VectorizeDTypeMismatch { expected: DType, actual: DType },

    /// GEP index out of bounds for a vector with `vcount` elements.
    #[snafu(display("GEP index {index} out of bounds for vector with {vcount} elements"))]
    GepIndexOutOfBounds { index: usize, vcount: usize },

    /// GEP requires vector source (vcount > 1); `dtype` is the source dtype supplied.
    #[snafu(display("GEP requires vector source (vcount > 1), got {dtype:?}"))]
    GepRequiresVector { dtype: DType },

    /// CONTRACT dtype count != axis product.
    #[snafu(display("CONTRACT dtype count {dtype_count} != axis product {axis_product}"))]
    ContractCountMismatch { dtype_count: usize, axis_product: usize },

    /// UNROLL src dtype count != axis product.
    #[snafu(display("UNROLL src dtype count {dtype_count} != axis product {axis_product}"))]
    UnrollCountMismatch { dtype_count: usize, axis_product: usize },

    /// WHERE condition must be bool; `actual` is the condition dtype supplied.
    #[snafu(display("WHERE condition must be bool, got {actual:?}"))]
    WhereConditionNotBool { actual: DType },

    /// BROADCAST requires scalar source (vcount=1); `dtype` is the source dtype supplied.
    #[snafu(display("BROADCAST requires scalar source (vcount=1), got {dtype:?}"))]
    BroadcastRequiresScalar { dtype: DType },

    /// MulAcc operands must have matching dtypes (including vcount).
    #[snafu(display(
        "MulAcc operands must have matching dtypes (including vcount): a={a_dtype:?}, b={b_dtype:?}, c={c_dtype:?}"
    ))]
    MulAccDtypeMismatch { a_dtype: DType, b_dtype: DType, c_dtype: DType },
}
175
176/// Enhance an error with provenance information for a UOp.
177///
178/// This function retrieves the provenance chain for a UOp and logs it,
179/// providing detailed debugging information about the operation's origin and
180/// transformation history.
181pub fn log_provenance(uop_id: u64, error: &Error) {
182    use crate::provenance::{PROVENANCE_TRACKER, format_chain};
183
184    PROVENANCE_TRACKER.with(|tracker| {
185        let chain = tracker.borrow().get_chain(uop_id);
186        if !chain.is_empty() {
187            tracing::error!(
188                uop.id = uop_id,
189                error = %error,
190                provenance_chain = %format_chain(&chain),
191                "uop error with provenance"
192            );
193        }
194    });
195}