1use std::fmt;
7
/// Error type shared by the kernel operations (for example `khatri_rao`,
/// `mttkrp`, `hadamard`, `outer_product`).
///
/// Each variant stores structured data rather than a pre-rendered string, so
/// callers can match on the fields; the human-readable message is produced by
/// the [`fmt::Display`] implementation that follows this enum.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum KernelError {
    /// An operand's dimensions did not match what `operation` required.
    DimensionMismatch {
        /// Name of the operation that detected the mismatch.
        operation: String,
        /// Dimensions the operation required.
        expected: Vec<usize>,
        /// Dimensions that were actually supplied.
        actual: Vec<usize>,
        /// Extra free-form explanation; may be empty.
        context: String,
    },

    /// A mode (axis) index was out of range.
    InvalidMode {
        /// The offending mode index.
        mode: usize,
        /// Exclusive upper bound: valid modes satisfy `mode < max_mode`.
        max_mode: usize,
        /// Extra free-form explanation; may be empty.
        context: String,
    },

    /// A factor's rank differed from the rank the operation expected.
    RankMismatch {
        /// Name of the operation that detected the mismatch.
        operation: String,
        /// Rank the operation required.
        expected_rank: usize,
        /// Rank that was actually supplied.
        actual_rank: usize,
        /// Index of the factor whose rank was wrong.
        factor_index: usize,
    },

    /// A required input collection was empty.
    EmptyInput {
        /// Name of the operation that rejected the input.
        operation: String,
        /// Name of the parameter that was empty.
        parameter: String,
    },

    /// A tile (block) size argument was rejected.
    InvalidTileSize {
        /// Name of the operation that rejected the tile size.
        operation: String,
        /// The rejected tile size.
        tile_size: usize,
        /// Why the tile size is invalid.
        reason: String,
    },

    /// Two operand shapes cannot be combined by `operation`.
    IncompatibleShapes {
        /// Name of the operation that rejected the shapes.
        operation: String,
        /// Shape of the first operand.
        shape_a: Vec<usize>,
        /// Shape of the second operand.
        shape_b: Vec<usize>,
        /// Why the shapes are incompatible.
        reason: String,
    },

    /// Generic failure described only by a free-form message.
    OperationError { operation: String, message: String },
}

impl fmt::Display for KernelError {
    /// Renders the error, prefixed with the operation name for every variant
    /// except `InvalidMode` (which carries no operation).
    ///
    /// A `context` string is appended after a single space only when it is
    /// non-empty, so an empty context does not leave a dangling space at the
    /// end of the message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Writes " <context>" unless `context` is empty.
        fn append_context(f: &mut fmt::Formatter<'_>, context: &str) -> fmt::Result {
            if context.is_empty() {
                Ok(())
            } else {
                write!(f, " {}", context)
            }
        }

        match self {
            KernelError::DimensionMismatch {
                operation,
                expected,
                actual,
                context,
            } => {
                write!(
                    f,
                    "{}: dimension mismatch - expected {:?}, got {:?}.",
                    operation, expected, actual
                )?;
                append_context(f, context)
            }

            KernelError::InvalidMode {
                mode,
                max_mode,
                context,
            } => {
                write!(f, "Invalid mode {}: must be < {}.", mode, max_mode)?;
                append_context(f, context)
            }

            KernelError::RankMismatch {
                operation,
                expected_rank,
                actual_rank,
                factor_index,
            } => write!(
                f,
                "{}: rank mismatch at factor {}: expected rank {}, got {}",
                operation, factor_index, expected_rank, actual_rank
            ),

            KernelError::EmptyInput {
                operation,
                parameter,
            } => write!(
                f,
                "{}: empty input not allowed for parameter '{}'",
                operation, parameter
            ),

            KernelError::InvalidTileSize {
                operation,
                tile_size,
                reason,
            } => write!(
                f,
                "{}: invalid tile size {}: {}",
                operation, tile_size, reason
            ),

            KernelError::IncompatibleShapes {
                operation,
                shape_a,
                shape_b,
                reason,
            } => write!(
                f,
                "{}: incompatible shapes {:?} and {:?}: {}",
                operation, shape_a, shape_b, reason
            ),

            KernelError::OperationError { operation, message } => {
                write!(f, "{}: {}", operation, message)
            }
        }
    }
}
130
131impl std::error::Error for KernelError {}
132
133pub type KernelResult<T> = Result<T, KernelError>;
135
136impl KernelError {
137 pub fn dimension_mismatch(
139 operation: impl Into<String>,
140 expected: Vec<usize>,
141 actual: Vec<usize>,
142 context: impl Into<String>,
143 ) -> Self {
144 KernelError::DimensionMismatch {
145 operation: operation.into(),
146 expected,
147 actual,
148 context: context.into(),
149 }
150 }
151
152 pub fn invalid_mode(mode: usize, max_mode: usize, context: impl Into<String>) -> Self {
154 KernelError::InvalidMode {
155 mode,
156 max_mode,
157 context: context.into(),
158 }
159 }
160
161 pub fn rank_mismatch(
163 operation: impl Into<String>,
164 expected_rank: usize,
165 actual_rank: usize,
166 factor_index: usize,
167 ) -> Self {
168 KernelError::RankMismatch {
169 operation: operation.into(),
170 expected_rank,
171 actual_rank,
172 factor_index,
173 }
174 }
175
176 pub fn empty_input(operation: impl Into<String>, parameter: impl Into<String>) -> Self {
178 KernelError::EmptyInput {
179 operation: operation.into(),
180 parameter: parameter.into(),
181 }
182 }
183
184 pub fn invalid_tile_size(
186 operation: impl Into<String>,
187 tile_size: usize,
188 reason: impl Into<String>,
189 ) -> Self {
190 KernelError::InvalidTileSize {
191 operation: operation.into(),
192 tile_size,
193 reason: reason.into(),
194 }
195 }
196
197 pub fn incompatible_shapes(
199 operation: impl Into<String>,
200 shape_a: Vec<usize>,
201 shape_b: Vec<usize>,
202 reason: impl Into<String>,
203 ) -> Self {
204 KernelError::IncompatibleShapes {
205 operation: operation.into(),
206 shape_a,
207 shape_b,
208 reason: reason.into(),
209 }
210 }
211
212 pub fn operation_error(operation: impl Into<String>, message: impl Into<String>) -> Self {
214 KernelError::OperationError {
215 operation: operation.into(),
216 message: message.into(),
217 }
218 }
219}
220
// Unit tests for `KernelError` formatting, constructors and trait impls.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dimension_mismatch_display() {
        let err = KernelError::dimension_mismatch(
            "khatri_rao",
            vec![10, 5],
            vec![10, 3],
            "Number of columns must match",
        );

        let msg = format!("{}", err);
        assert!(msg.contains("khatri_rao"));
        assert!(msg.contains("dimension mismatch"));
        assert!(msg.contains("[10, 5]"));
        assert!(msg.contains("[10, 3]"));
        assert!(msg.contains("Number of columns must match"));
    }

    #[test]
    fn test_invalid_mode_display() {
        let err = KernelError::invalid_mode(3, 3, "Tensor has only 3 modes");

        let msg = format!("{}", err);
        assert!(msg.contains("Invalid mode 3"));
        assert!(msg.contains("must be < 3"));
        assert!(msg.contains("Tensor has only 3 modes"));
    }

    #[test]
    fn test_rank_mismatch_display() {
        let err = KernelError::rank_mismatch("mttkrp", 5, 3, 2);

        let msg = format!("{}", err);
        assert!(msg.contains("mttkrp"));
        assert!(msg.contains("factor 2"));
        assert!(msg.contains("expected rank 5"));
        assert!(msg.contains("got 3"));
    }

    #[test]
    fn test_empty_input_display() {
        let err = KernelError::empty_input("outer_product", "vectors");

        let msg = format!("{}", err);
        assert!(msg.contains("outer_product"));
        assert!(msg.contains("empty input"));
        assert!(msg.contains("vectors"));
    }

    #[test]
    fn test_invalid_tile_size_display() {
        let err = KernelError::invalid_tile_size("mttkrp_blocked", 0, "must be positive");

        let msg = format!("{}", err);
        assert!(msg.contains("mttkrp_blocked"));
        assert!(msg.contains("invalid tile size 0"));
        assert!(msg.contains("must be positive"));
    }

    #[test]
    fn test_incompatible_shapes_display() {
        let err = KernelError::incompatible_shapes(
            "hadamard",
            vec![2, 3],
            vec![2, 4],
            "Element-wise multiplication requires same shape",
        );

        let msg = format!("{}", err);
        assert!(msg.contains("hadamard"));
        assert!(msg.contains("[2, 3]"));
        assert!(msg.contains("[2, 4]"));
        assert!(msg.contains("Element-wise multiplication"));
    }

    #[test]
    fn test_operation_error_display() {
        let err = KernelError::operation_error("cp_als", "did not converge");

        assert_eq!(err.to_string(), "cp_als: did not converge");
    }

    #[test]
    fn test_propagates_as_std_error() {
        use std::error::Error;

        fn fails() -> KernelResult<()> {
            Err(KernelError::empty_input("outer_product", "vectors"))
        }

        // `?` must convert `KernelError` into a boxed trait object.
        fn propagate() -> Result<(), Box<dyn Error>> {
            fails()?;
            Ok(())
        }

        let err = propagate().unwrap_err();
        assert!(err.source().is_none());
        assert!(err.to_string().contains("outer_product"));
    }

    #[test]
    fn test_clone_compares_equal() {
        let err = KernelError::rank_mismatch("mttkrp", 5, 3, 2);

        assert_eq!(err.clone(), err);
        assert_ne!(err, KernelError::rank_mismatch("mttkrp", 5, 3, 1));
    }
}