Skip to main content

oxiphysics_gpu/
error.rs

1// Copyright 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! Error types for oxiphysics-gpu
5
6#![allow(dead_code)]
7
8use thiserror::Error;
9
10/// Main error type for the gpu module.
11#[derive(Debug, Error)]
12pub enum Error {
13    /// Generic error with a free-form message.
14    #[error("{0}")]
15    General(String),
16
17    /// A GPU buffer allocation failed.
18    #[error(
19        "buffer allocation failed: requested {requested_bytes} bytes (available {available_bytes})"
20    )]
21    BufferAllocationFailed {
22        /// Bytes requested.
23        requested_bytes: usize,
24        /// Bytes actually available.
25        available_bytes: usize,
26    },
27
28    /// An invalid buffer handle was used.
29    #[error("invalid buffer handle: {0}")]
30    InvalidBufferHandle(usize),
31
32    /// A shader compilation error (mock).
33    #[error("shader compilation error in '{shader}': {message}")]
34    ShaderCompilationError {
35        /// Name of the offending shader.
36        shader: String,
37        /// Compiler message.
38        message: String,
39    },
40
41    /// A dispatch exceeded the hardware work-group limit.
42    #[error("dispatch size {dispatch_size} exceeds hardware limit {limit}")]
43    DispatchLimitExceeded {
44        /// Requested dispatch size (number of work-groups).
45        dispatch_size: usize,
46        /// Hardware maximum.
47        limit: usize,
48    },
49
50    /// Out-of-bounds grid access.
51    #[error("grid index ({i}, {j}, {k}) out of bounds for grid ({nx}, {ny}, {nz})")]
52    GridIndexOutOfBounds {
53        /// Requested x index.
54        i: usize,
55        /// Requested y index.
56        j: usize,
57        /// Requested z index.
58        k: usize,
59        /// Grid x dimension.
60        nx: usize,
61        /// Grid y dimension.
62        ny: usize,
63        /// Grid z dimension.
64        nz: usize,
65    },
66
67    /// A kernel argument count mismatch.
68    #[error("kernel '{kernel}' expects {expected} arguments but got {got}")]
69    KernelArgCountMismatch {
70        /// Kernel name.
71        kernel: String,
72        /// Expected number of arguments.
73        expected: usize,
74        /// Provided number of arguments.
75        got: usize,
76    },
77
78    /// An unsupported backend feature was requested.
79    #[error("unsupported feature: {feature}")]
80    UnsupportedFeature {
81        /// Description of the unsupported feature.
82        feature: String,
83    },
84}
85
86/// A pipeline-stage error: carries the stage name plus the underlying cause.
87#[derive(Debug, Error)]
88#[error("pipeline stage '{stage}' failed: {source}")]
89pub struct PipelineStageError {
90    /// Name of the pipeline stage (e.g. `"vertex_fetch"`, `"sph_density"`).
91    pub stage: String,
92    /// Root cause.
93    pub source: Box<Error>,
94}
95
/// Severity level attached to a GPU error.
///
/// Variants are declared from least to most severe, so the derived `Ord`
/// ranks `Info < Warning < Fatal`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ErrorSeverity {
    /// Informational — execution can continue.
    Info,
    /// Warning — partial results may be degraded.
    Warning,
    /// Fatal — must abort current dispatch.
    Fatal,
}

impl ErrorSeverity {
    /// Upper-case label for this severity: `"INFO"`, `"WARNING"` or `"FATAL"`.
    pub fn as_str(&self) -> &'static str {
        match self {
            ErrorSeverity::Info => "INFO",
            ErrorSeverity::Warning => "WARNING",
            ErrorSeverity::Fatal => "FATAL",
        }
    }
}

impl std::fmt::Display for ErrorSeverity {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` (unlike `write!`) honours width/alignment flags such as `{:<7}`,
        // so severity labels can be column-aligned in log output.
        f.pad(self.as_str())
    }
}
116
117/// An error annotated with severity and an optional kernel name.
118#[derive(Debug)]
119pub struct AnnotatedError {
120    /// The underlying error.
121    pub error: Error,
122    /// Severity classification.
123    pub severity: ErrorSeverity,
124    /// Optional kernel that triggered the error.
125    pub kernel: Option<String>,
126}
127
128impl std::fmt::Display for AnnotatedError {
129    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130        if let Some(ref k) = self.kernel {
131            write!(f, "[{}] kernel '{}': {}", self.severity, k, self.error)
132        } else {
133            write!(f, "[{}] {}", self.severity, self.error)
134        }
135    }
136}
137
138impl AnnotatedError {
139    /// Wrap an error as fatal with an optional kernel label.
140    pub fn fatal(error: Error, kernel: Option<&str>) -> Self {
141        Self {
142            error,
143            severity: ErrorSeverity::Fatal,
144            kernel: kernel.map(str::to_string),
145        }
146    }
147
148    /// Wrap an error as a warning.
149    pub fn warning(error: Error, kernel: Option<&str>) -> Self {
150        Self {
151            error,
152            severity: ErrorSeverity::Warning,
153            kernel: kernel.map(str::to_string),
154        }
155    }
156}
157
158/// Result type alias
159pub type Result<T> = std::result::Result<T, Error>;
160
161impl Error {
162    /// Construct a [`Error::General`] from any `Display`-able value.
163    pub fn general(msg: impl std::fmt::Display) -> Self {
164        Error::General(msg.to_string())
165    }
166
167    /// True when this is a recoverable allocation error.
168    pub fn is_allocation_error(&self) -> bool {
169        matches!(self, Error::BufferAllocationFailed { .. })
170    }
171
172    /// True when this is a shader compilation error.
173    pub fn is_shader_error(&self) -> bool {
174        matches!(self, Error::ShaderCompilationError { .. })
175    }
176
177    /// True when this is an out-of-bounds grid error.
178    pub fn is_grid_error(&self) -> bool {
179        matches!(self, Error::GridIndexOutOfBounds { .. })
180    }
181
182    /// True when this is a kernel argument mismatch error.
183    pub fn is_arg_mismatch(&self) -> bool {
184        matches!(self, Error::KernelArgCountMismatch { .. })
185    }
186
187    /// True when this is an unsupported feature error.
188    pub fn is_unsupported(&self) -> bool {
189        matches!(self, Error::UnsupportedFeature { .. })
190    }
191
192    /// Wrap this error in a [`PipelineStageError`].
193    pub fn in_stage(self, stage: impl Into<String>) -> PipelineStageError {
194        PipelineStageError {
195            stage: stage.into(),
196            source: Box::new(self),
197        }
198    }
199
200    /// Annotate with fatal severity.
201    pub fn fatal(self, kernel: Option<&str>) -> AnnotatedError {
202        AnnotatedError::fatal(self, kernel)
203    }
204
205    /// Annotate with warning severity.
206    pub fn warning(self, kernel: Option<&str>) -> AnnotatedError {
207        AnnotatedError::warning(self, kernel)
208    }
209
210    /// Convert into a `Result::Err`.
211    pub fn into_err<T>(self) -> Result<T> {
212        Err(self)
213    }
214}
215
216// ── Convenience constructors ─────────────────────────────────────────────────
217
218/// Build a [`Error::BufferAllocationFailed`] error.
219pub fn alloc_err(requested_bytes: usize, available_bytes: usize) -> Error {
220    Error::BufferAllocationFailed {
221        requested_bytes,
222        available_bytes,
223    }
224}
225
226/// Build a [`Error::KernelArgCountMismatch`] error.
227pub fn arg_mismatch_err(kernel: impl Into<String>, expected: usize, got: usize) -> Error {
228    Error::KernelArgCountMismatch {
229        kernel: kernel.into(),
230        expected,
231        got,
232    }
233}
234
235/// Build a [`Error::GridIndexOutOfBounds`] error.
236#[allow(clippy::too_many_arguments)]
237pub fn grid_oob_err(i: usize, j: usize, k: usize, nx: usize, ny: usize, nz: usize) -> Error {
238    Error::GridIndexOutOfBounds {
239        i,
240        j,
241        k,
242        nx,
243        ny,
244        nz,
245    }
246}
247
248/// Build a [`Error::DispatchLimitExceeded`] error.
249pub fn dispatch_limit_err(dispatch_size: usize, limit: usize) -> Error {
250    Error::DispatchLimitExceeded {
251        dispatch_size,
252        limit,
253    }
254}
255
256/// Build a [`Error::ShaderCompilationError`].
257pub fn shader_err(shader: impl Into<String>, message: impl Into<String>) -> Error {
258    Error::ShaderCompilationError {
259        shader: shader.into(),
260        message: message.into(),
261    }
262}
263
264/// Build an [`Error::UnsupportedFeature`].
265pub fn unsupported_err(feature: impl Into<String>) -> Error {
266    Error::UnsupportedFeature {
267        feature: feature.into(),
268    }
269}
270
271// ── Error collection ─────────────────────────────────────────────────────────
272
273/// Collect multiple errors from a batch dispatch.  Returns `Ok(())` if the
274/// vec is empty, or `Err` containing the first error otherwise.
275pub fn collect_errors(errors: Vec<Error>) -> Result<()> {
276    errors.into_iter().next().map_or(Ok(()), Err)
277}
278
279/// Check a boolean condition; return `Err(Error::General(msg))` if false.
280pub fn check(condition: bool, msg: impl std::fmt::Display) -> Result<()> {
281    if condition {
282        Ok(())
283    } else {
284        Err(Error::general(msg))
285    }
286}
287
#[cfg(test)]
mod error_tests {
    use super::*;

    // Exact-message assertions pin the `#[error(...)]` format strings, so a
    // change in wording or field order is caught rather than slipping past a
    // loose `contains` check.

    #[test]
    fn test_general_error_message() {
        let e = Error::general("something went wrong");
        assert_eq!(e.to_string(), "something went wrong");
    }

    #[test]
    fn test_buffer_allocation_failed_message() {
        let e = Error::BufferAllocationFailed {
            requested_bytes: 1024,
            available_bytes: 512,
        };
        assert_eq!(
            e.to_string(),
            "buffer allocation failed: requested 1024 bytes (available 512)"
        );
        assert!(e.is_allocation_error());
    }

    #[test]
    fn test_invalid_buffer_handle() {
        let e = Error::InvalidBufferHandle(42);
        assert_eq!(e.to_string(), "invalid buffer handle: 42");
    }

    #[test]
    fn test_shader_compilation_error() {
        let e = Error::ShaderCompilationError {
            shader: "sph_density".to_string(),
            message: "undefined symbol".to_string(),
        };
        assert_eq!(
            e.to_string(),
            "shader compilation error in 'sph_density': undefined symbol"
        );
        assert!(e.is_shader_error());
    }

    #[test]
    fn test_dispatch_limit_exceeded() {
        let e = Error::DispatchLimitExceeded {
            dispatch_size: 100_000,
            limit: 65535,
        };
        assert_eq!(
            e.to_string(),
            "dispatch size 100000 exceeds hardware limit 65535"
        );
    }

    #[test]
    fn test_grid_index_out_of_bounds() {
        let e = Error::GridIndexOutOfBounds {
            i: 10,
            j: 5,
            k: 3,
            nx: 8,
            ny: 8,
            nz: 8,
        };
        assert_eq!(
            e.to_string(),
            "grid index (10, 5, 3) out of bounds for grid (8, 8, 8)"
        );
        assert!(e.is_grid_error());
    }

    #[test]
    fn test_is_not_shader_error() {
        let e = Error::general("not a shader error");
        assert!(!e.is_shader_error());
    }

    #[test]
    fn test_unsupported_feature() {
        let e = Error::UnsupportedFeature {
            feature: "ray_tracing".to_string(),
        };
        assert_eq!(e.to_string(), "unsupported feature: ray_tracing");
    }

    // ── New error variant / helper tests ─────────────────────────────────

    #[test]
    fn test_is_grid_error() {
        let e = grid_oob_err(1, 2, 3, 4, 5, 6);
        assert!(e.is_grid_error());
        assert!(!e.is_allocation_error());
        // Distinct values verify the constructor keeps the argument order.
        assert_eq!(
            e.to_string(),
            "grid index (1, 2, 3) out of bounds for grid (4, 5, 6)"
        );
    }

    #[test]
    fn test_is_arg_mismatch() {
        let e = arg_mismatch_err("test_kernel", 3, 2);
        assert!(e.is_arg_mismatch());
        assert!(!e.is_shader_error());
        assert_eq!(
            e.to_string(),
            "kernel 'test_kernel' expects 3 arguments but got 2"
        );
    }

    #[test]
    fn test_is_unsupported() {
        let e = unsupported_err("ray_tracing");
        assert!(e.is_unsupported());
        assert_eq!(e.to_string(), "unsupported feature: ray_tracing");
    }

    #[test]
    fn test_in_stage_wraps_error() {
        let e = Error::general("boom");
        let wrapped = e.in_stage("sph_density");
        assert_eq!(wrapped.stage, "sph_density");
        assert_eq!(
            wrapped.to_string(),
            "pipeline stage 'sph_density' failed: boom"
        );
        // The wrapped cause must be reachable through the error chain.
        let source = std::error::Error::source(&wrapped).map(|s| s.to_string());
        assert_eq!(source.as_deref(), Some("boom"));
    }

    #[test]
    fn test_alloc_err_convenience() {
        let e = alloc_err(512, 256);
        assert!(e.is_allocation_error());
        assert_eq!(
            e.to_string(),
            "buffer allocation failed: requested 512 bytes (available 256)"
        );
    }

    #[test]
    fn test_dispatch_limit_err_convenience() {
        let e = dispatch_limit_err(99999, 65535);
        assert_eq!(
            e.to_string(),
            "dispatch size 99999 exceeds hardware limit 65535"
        );
    }

    #[test]
    fn test_shader_err_convenience() {
        let e = shader_err("my_shader", "syntax error");
        assert!(e.is_shader_error());
        assert_eq!(
            e.to_string(),
            "shader compilation error in 'my_shader': syntax error"
        );
    }

    #[test]
    fn test_into_err() {
        let result: Result<i32> = Error::general("nope").into_err();
        assert_eq!(result.unwrap_err().to_string(), "nope");
    }

    #[test]
    fn test_collect_errors_empty() {
        assert!(collect_errors(vec![]).is_ok());
    }

    #[test]
    fn test_collect_errors_nonempty() {
        let errs = vec![Error::general("first"), Error::general("second")];
        let result = collect_errors(errs);
        assert_eq!(result.unwrap_err().to_string(), "first");
    }

    #[test]
    fn test_check_passes() {
        assert!(check(true, "should not fail").is_ok());
    }

    #[test]
    fn test_check_fails() {
        let r = check(false, "condition violated");
        assert_eq!(r.unwrap_err().to_string(), "condition violated");
    }

    #[test]
    fn test_annotated_error_fatal_display() {
        let e = Error::general("crash");
        let ann = e.fatal(Some("sph_kernel"));
        assert_eq!(ann.severity, ErrorSeverity::Fatal);
        assert_eq!(ann.kernel.as_deref(), Some("sph_kernel"));
        assert_eq!(ann.to_string(), "[FATAL] kernel 'sph_kernel': crash");
    }

    #[test]
    fn test_annotated_error_warning_no_kernel() {
        let e = Error::general("degraded");
        let ann = e.warning(None);
        assert_eq!(ann.severity, ErrorSeverity::Warning);
        assert!(ann.kernel.is_none());
        assert_eq!(ann.to_string(), "[WARNING] degraded");
    }

    #[test]
    fn test_error_severity_ordering() {
        assert!(ErrorSeverity::Info < ErrorSeverity::Warning);
        assert!(ErrorSeverity::Warning < ErrorSeverity::Fatal);
    }

    #[test]
    fn test_error_severity_display() {
        assert_eq!(ErrorSeverity::Info.to_string(), "INFO");
        assert_eq!(ErrorSeverity::Warning.to_string(), "WARNING");
        assert_eq!(ErrorSeverity::Fatal.to_string(), "FATAL");
    }
}