1#![allow(dead_code)]
7
8use thiserror::Error;
9
10#[derive(Debug, Error)]
12pub enum Error {
13 #[error("{0}")]
15 General(String),
16
17 #[error(
19 "buffer allocation failed: requested {requested_bytes} bytes (available {available_bytes})"
20 )]
21 BufferAllocationFailed {
22 requested_bytes: usize,
24 available_bytes: usize,
26 },
27
28 #[error("invalid buffer handle: {0}")]
30 InvalidBufferHandle(usize),
31
32 #[error("shader compilation error in '{shader}': {message}")]
34 ShaderCompilationError {
35 shader: String,
37 message: String,
39 },
40
41 #[error("dispatch size {dispatch_size} exceeds hardware limit {limit}")]
43 DispatchLimitExceeded {
44 dispatch_size: usize,
46 limit: usize,
48 },
49
50 #[error("grid index ({i}, {j}, {k}) out of bounds for grid ({nx}, {ny}, {nz})")]
52 GridIndexOutOfBounds {
53 i: usize,
55 j: usize,
57 k: usize,
59 nx: usize,
61 ny: usize,
63 nz: usize,
65 },
66
67 #[error("kernel '{kernel}' expects {expected} arguments but got {got}")]
69 KernelArgCountMismatch {
70 kernel: String,
72 expected: usize,
74 got: usize,
76 },
77
78 #[error("unsupported feature: {feature}")]
80 UnsupportedFeature {
81 feature: String,
83 },
84}
85
86#[derive(Debug, Error)]
88#[error("pipeline stage '{stage}' failed: {source}")]
89pub struct PipelineStageError {
90 pub stage: String,
92 pub source: Box<Error>,
94}
95
/// Severity attached to an annotated error.
///
/// The derived `Ord` follows declaration order, so `Info < Warning < Fatal`.
/// `Hash` is derived so severities can key hash sets and maps.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ErrorSeverity {
    Info,
    Warning,
    Fatal,
}

impl std::fmt::Display for ErrorSeverity {
    /// Writes the upper-case label: `INFO`, `WARNING` or `FATAL`.
    ///
    /// `Formatter::pad` is used instead of `write!` so that width, fill and
    /// alignment flags (e.g. `{:<7}`) are honoured, keeping columnar log
    /// output aligned. With no flags the output is just the label.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            ErrorSeverity::Info => "INFO",
            ErrorSeverity::Warning => "WARNING",
            ErrorSeverity::Fatal => "FATAL",
        };
        f.pad(label)
    }
}
116
117#[derive(Debug)]
119pub struct AnnotatedError {
120 pub error: Error,
122 pub severity: ErrorSeverity,
124 pub kernel: Option<String>,
126}
127
128impl std::fmt::Display for AnnotatedError {
129 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130 if let Some(ref k) = self.kernel {
131 write!(f, "[{}] kernel '{}': {}", self.severity, k, self.error)
132 } else {
133 write!(f, "[{}] {}", self.severity, self.error)
134 }
135 }
136}
137
138impl AnnotatedError {
139 pub fn fatal(error: Error, kernel: Option<&str>) -> Self {
141 Self {
142 error,
143 severity: ErrorSeverity::Fatal,
144 kernel: kernel.map(str::to_string),
145 }
146 }
147
148 pub fn warning(error: Error, kernel: Option<&str>) -> Self {
150 Self {
151 error,
152 severity: ErrorSeverity::Warning,
153 kernel: kernel.map(str::to_string),
154 }
155 }
156}
157
158pub type Result<T> = std::result::Result<T, Error>;
160
161#[derive(Debug, Error)]
163pub enum GpuError {
164 #[error("GPU backend init failed: {0}")]
166 BackendInit(String),
167
168 #[error("shader dispatch error: {0}")]
170 ShaderDispatch(String),
171
172 #[error("GPU buffer read-back failed: {0}")]
174 ReadBack(String),
175
176 #[error("invalid GPU buffer handle: {0}")]
178 InvalidHandle(usize),
179}
180
181impl Error {
182 pub fn general(msg: impl std::fmt::Display) -> Self {
184 Error::General(msg.to_string())
185 }
186
187 pub fn is_allocation_error(&self) -> bool {
189 matches!(self, Error::BufferAllocationFailed { .. })
190 }
191
192 pub fn is_shader_error(&self) -> bool {
194 matches!(self, Error::ShaderCompilationError { .. })
195 }
196
197 pub fn is_grid_error(&self) -> bool {
199 matches!(self, Error::GridIndexOutOfBounds { .. })
200 }
201
202 pub fn is_arg_mismatch(&self) -> bool {
204 matches!(self, Error::KernelArgCountMismatch { .. })
205 }
206
207 pub fn is_unsupported(&self) -> bool {
209 matches!(self, Error::UnsupportedFeature { .. })
210 }
211
212 pub fn in_stage(self, stage: impl Into<String>) -> PipelineStageError {
214 PipelineStageError {
215 stage: stage.into(),
216 source: Box::new(self),
217 }
218 }
219
220 pub fn fatal(self, kernel: Option<&str>) -> AnnotatedError {
222 AnnotatedError::fatal(self, kernel)
223 }
224
225 pub fn warning(self, kernel: Option<&str>) -> AnnotatedError {
227 AnnotatedError::warning(self, kernel)
228 }
229
230 pub fn into_err<T>(self) -> Result<T> {
232 Err(self)
233 }
234}
235
236pub fn alloc_err(requested_bytes: usize, available_bytes: usize) -> Error {
240 Error::BufferAllocationFailed {
241 requested_bytes,
242 available_bytes,
243 }
244}
245
246pub fn arg_mismatch_err(kernel: impl Into<String>, expected: usize, got: usize) -> Error {
248 Error::KernelArgCountMismatch {
249 kernel: kernel.into(),
250 expected,
251 got,
252 }
253}
254
255#[allow(clippy::too_many_arguments)]
257pub fn grid_oob_err(i: usize, j: usize, k: usize, nx: usize, ny: usize, nz: usize) -> Error {
258 Error::GridIndexOutOfBounds {
259 i,
260 j,
261 k,
262 nx,
263 ny,
264 nz,
265 }
266}
267
268pub fn dispatch_limit_err(dispatch_size: usize, limit: usize) -> Error {
270 Error::DispatchLimitExceeded {
271 dispatch_size,
272 limit,
273 }
274}
275
276pub fn shader_err(shader: impl Into<String>, message: impl Into<String>) -> Error {
278 Error::ShaderCompilationError {
279 shader: shader.into(),
280 message: message.into(),
281 }
282}
283
284pub fn unsupported_err(feature: impl Into<String>) -> Error {
286 Error::UnsupportedFeature {
287 feature: feature.into(),
288 }
289}
290
291pub fn collect_errors(errors: Vec<Error>) -> Result<()> {
296 errors.into_iter().next().map_or(Ok(()), Err)
297}
298
299pub fn check(condition: bool, msg: impl std::fmt::Display) -> Result<()> {
301 if condition {
302 Ok(())
303 } else {
304 Err(Error::general(msg))
305 }
306}
307
#[cfg(test)]
mod error_tests {
    //! Tests pin exact rendered messages where practical: substring checks
    //! such as `contains("8")` pass for almost any message and catch little.
    use super::*;

    #[test]
    fn test_general_error_message() {
        let e = Error::general("something went wrong");
        assert_eq!(e.to_string(), "something went wrong");
    }

    #[test]
    fn test_buffer_allocation_failed_message() {
        let e = Error::BufferAllocationFailed {
            requested_bytes: 1024,
            available_bytes: 512,
        };
        assert_eq!(
            e.to_string(),
            "buffer allocation failed: requested 1024 bytes (available 512)"
        );
        assert!(e.is_allocation_error());
    }

    #[test]
    fn test_invalid_buffer_handle() {
        let e = Error::InvalidBufferHandle(42);
        assert_eq!(e.to_string(), "invalid buffer handle: 42");
    }

    #[test]
    fn test_shader_compilation_error() {
        let e = Error::ShaderCompilationError {
            shader: "sph_density".to_string(),
            message: "undefined symbol".to_string(),
        };
        assert_eq!(
            e.to_string(),
            "shader compilation error in 'sph_density': undefined symbol"
        );
        assert!(e.is_shader_error());
    }

    #[test]
    fn test_dispatch_limit_exceeded() {
        let e = Error::DispatchLimitExceeded {
            dispatch_size: 100_000,
            limit: 65535,
        };
        assert_eq!(
            e.to_string(),
            "dispatch size 100000 exceeds hardware limit 65535"
        );
    }

    #[test]
    fn test_grid_index_out_of_bounds() {
        let e = Error::GridIndexOutOfBounds {
            i: 10,
            j: 5,
            k: 3,
            nx: 8,
            ny: 8,
            nz: 8,
        };
        assert_eq!(
            e.to_string(),
            "grid index (10, 5, 3) out of bounds for grid (8, 8, 8)"
        );
    }

    #[test]
    fn test_is_not_shader_error() {
        let e = Error::general("not a shader error");
        assert!(!e.is_shader_error());
    }

    #[test]
    fn test_unsupported_feature() {
        let e = Error::UnsupportedFeature {
            feature: "ray_tracing".to_string(),
        };
        assert_eq!(e.to_string(), "unsupported feature: ray_tracing");
    }

    #[test]
    fn test_is_grid_error() {
        let e = grid_oob_err(1, 2, 3, 4, 5, 6);
        assert!(e.is_grid_error());
        assert!(!e.is_allocation_error());
    }

    #[test]
    fn test_is_arg_mismatch() {
        let e = arg_mismatch_err("test_kernel", 3, 2);
        assert!(e.is_arg_mismatch());
        assert!(!e.is_shader_error());
        assert_eq!(
            e.to_string(),
            "kernel 'test_kernel' expects 3 arguments but got 2"
        );
    }

    #[test]
    fn test_is_unsupported() {
        let e = unsupported_err("ray_tracing");
        assert!(e.is_unsupported());
    }

    #[test]
    fn test_in_stage_wraps_error() {
        let e = Error::general("boom");
        let wrapped = e.in_stage("sph_density");
        assert_eq!(wrapped.to_string(), "pipeline stage 'sph_density' failed: boom");
        // The wrapped error must also be reachable through the error chain.
        let source = std::error::Error::source(&wrapped).expect("wrapped error is the source");
        assert_eq!(source.to_string(), "boom");
    }

    #[test]
    fn test_alloc_err_convenience() {
        let e = alloc_err(512, 256);
        assert!(e.is_allocation_error());
        assert_eq!(
            e.to_string(),
            "buffer allocation failed: requested 512 bytes (available 256)"
        );
    }

    #[test]
    fn test_dispatch_limit_err_convenience() {
        let e = dispatch_limit_err(99999, 65535);
        assert_eq!(e.to_string(), "dispatch size 99999 exceeds hardware limit 65535");
    }

    #[test]
    fn test_shader_err_convenience() {
        let e = shader_err("my_shader", "syntax error");
        assert!(e.is_shader_error());
        assert_eq!(
            e.to_string(),
            "shader compilation error in 'my_shader': syntax error"
        );
    }

    #[test]
    fn test_into_err() {
        let result: Result<i32> = Error::general("nope").into_err();
        assert!(result.is_err());
    }

    #[test]
    fn test_collect_errors_empty() {
        assert!(collect_errors(vec![]).is_ok());
    }

    #[test]
    fn test_collect_errors_nonempty() {
        let errs = vec![Error::general("first"), Error::general("second")];
        let result = collect_errors(errs);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().to_string(), "first");
    }

    #[test]
    fn test_check_passes() {
        assert!(check(true, "should not fail").is_ok());
    }

    #[test]
    fn test_check_fails() {
        let r = check(false, "condition violated");
        assert!(r.is_err());
        assert_eq!(r.unwrap_err().to_string(), "condition violated");
    }

    #[test]
    fn test_annotated_error_fatal_display() {
        let e = Error::general("crash");
        let ann = e.fatal(Some("sph_kernel"));
        assert_eq!(ann.to_string(), "[FATAL] kernel 'sph_kernel': crash");
    }

    #[test]
    fn test_annotated_error_warning_no_kernel() {
        let e = Error::general("degraded");
        let ann = e.warning(None);
        assert_eq!(ann.to_string(), "[WARNING] degraded");
    }

    #[test]
    fn test_error_severity_ordering() {
        assert!(ErrorSeverity::Info < ErrorSeverity::Warning);
        assert!(ErrorSeverity::Warning < ErrorSeverity::Fatal);
    }

    #[test]
    fn test_error_severity_display() {
        assert_eq!(ErrorSeverity::Info.to_string(), "INFO");
        assert_eq!(ErrorSeverity::Warning.to_string(), "WARNING");
        assert_eq!(ErrorSeverity::Fatal.to_string(), "FATAL");
    }
}