Skip to main content

oxiphysics_gpu/
error.rs

1// Copyright 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! Error types for oxiphysics-gpu
5
6#![allow(dead_code)]
7
8use thiserror::Error;
9
10/// Main error type for the gpu module.
11#[derive(Debug, Error)]
12pub enum Error {
13    /// Generic error with a free-form message.
14    #[error("{0}")]
15    General(String),
16
17    /// A GPU buffer allocation failed.
18    #[error(
19        "buffer allocation failed: requested {requested_bytes} bytes (available {available_bytes})"
20    )]
21    BufferAllocationFailed {
22        /// Bytes requested.
23        requested_bytes: usize,
24        /// Bytes actually available.
25        available_bytes: usize,
26    },
27
28    /// An invalid buffer handle was used.
29    #[error("invalid buffer handle: {0}")]
30    InvalidBufferHandle(usize),
31
32    /// A shader compilation error (mock).
33    #[error("shader compilation error in '{shader}': {message}")]
34    ShaderCompilationError {
35        /// Name of the offending shader.
36        shader: String,
37        /// Compiler message.
38        message: String,
39    },
40
41    /// A dispatch exceeded the hardware work-group limit.
42    #[error("dispatch size {dispatch_size} exceeds hardware limit {limit}")]
43    DispatchLimitExceeded {
44        /// Requested dispatch size (number of work-groups).
45        dispatch_size: usize,
46        /// Hardware maximum.
47        limit: usize,
48    },
49
50    /// Out-of-bounds grid access.
51    #[error("grid index ({i}, {j}, {k}) out of bounds for grid ({nx}, {ny}, {nz})")]
52    GridIndexOutOfBounds {
53        /// Requested x index.
54        i: usize,
55        /// Requested y index.
56        j: usize,
57        /// Requested z index.
58        k: usize,
59        /// Grid x dimension.
60        nx: usize,
61        /// Grid y dimension.
62        ny: usize,
63        /// Grid z dimension.
64        nz: usize,
65    },
66
67    /// A kernel argument count mismatch.
68    #[error("kernel '{kernel}' expects {expected} arguments but got {got}")]
69    KernelArgCountMismatch {
70        /// Kernel name.
71        kernel: String,
72        /// Expected number of arguments.
73        expected: usize,
74        /// Provided number of arguments.
75        got: usize,
76    },
77
78    /// An unsupported backend feature was requested.
79    #[error("unsupported feature: {feature}")]
80    UnsupportedFeature {
81        /// Description of the unsupported feature.
82        feature: String,
83    },
84}
85
86/// A pipeline-stage error: carries the stage name plus the underlying cause.
87#[derive(Debug, Error)]
88#[error("pipeline stage '{stage}' failed: {source}")]
89pub struct PipelineStageError {
90    /// Name of the pipeline stage (e.g. `"vertex_fetch"`, `"sph_density"`).
91    pub stage: String,
92    /// Root cause.
93    pub source: Box<Error>,
94}
95
/// Severity level for a GPU error.
///
/// Variants are declared from least to most severe, so the derived ordering
/// yields `Info < Warning < Fatal`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum ErrorSeverity {
    /// Informational — execution can continue.
    Info,
    /// Warning — partial results may be degraded.
    Warning,
    /// Fatal — must abort current dispatch.
    Fatal,
}

impl std::fmt::Display for ErrorSeverity {
    /// Writes the upper-case label of the severity: `INFO`, `WARNING` or
    /// `FATAL`.
    ///
    /// Width, fill, alignment and precision given by the caller (for example
    /// `{:<7}`) are applied to the label.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            ErrorSeverity::Info => "INFO",
            ErrorSeverity::Warning => "WARNING",
            ErrorSeverity::Fatal => "FATAL",
        };
        // `pad` honours the formatter's flags; `write!(f, "...")` would ignore
        // them and make column-aligned output impossible.
        f.pad(label)
    }
}
116
117/// An error annotated with severity and an optional kernel name.
118#[derive(Debug)]
119pub struct AnnotatedError {
120    /// The underlying error.
121    pub error: Error,
122    /// Severity classification.
123    pub severity: ErrorSeverity,
124    /// Optional kernel that triggered the error.
125    pub kernel: Option<String>,
126}
127
128impl std::fmt::Display for AnnotatedError {
129    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130        if let Some(ref k) = self.kernel {
131            write!(f, "[{}] kernel '{}': {}", self.severity, k, self.error)
132        } else {
133            write!(f, "[{}] {}", self.severity, self.error)
134        }
135    }
136}
137
138impl AnnotatedError {
139    /// Wrap an error as fatal with an optional kernel label.
140    pub fn fatal(error: Error, kernel: Option<&str>) -> Self {
141        Self {
142            error,
143            severity: ErrorSeverity::Fatal,
144            kernel: kernel.map(str::to_string),
145        }
146    }
147
148    /// Wrap an error as a warning.
149    pub fn warning(error: Error, kernel: Option<&str>) -> Self {
150        Self {
151            error,
152            severity: ErrorSeverity::Warning,
153            kernel: kernel.map(str::to_string),
154        }
155    }
156}
157
158/// Result type alias
159pub type Result<T> = std::result::Result<T, Error>;
160
161/// GPU operation error (used by `LbmGpuSolver`, `BvhGpuTraverser`, etc.).
162#[derive(Debug, Error)]
163pub enum GpuError {
164    /// GPU backend initialisation failed.
165    #[error("GPU backend init failed: {0}")]
166    BackendInit(String),
167
168    /// Shader dispatch / pipeline error.
169    #[error("shader dispatch error: {0}")]
170    ShaderDispatch(String),
171
172    /// Read-back from the GPU buffer failed.
173    #[error("GPU buffer read-back failed: {0}")]
174    ReadBack(String),
175
176    /// An invalid buffer handle was used.
177    #[error("invalid GPU buffer handle: {0}")]
178    InvalidHandle(usize),
179}
180
181impl Error {
182    /// Construct a [`Error::General`] from any `Display`-able value.
183    pub fn general(msg: impl std::fmt::Display) -> Self {
184        Error::General(msg.to_string())
185    }
186
187    /// True when this is a recoverable allocation error.
188    pub fn is_allocation_error(&self) -> bool {
189        matches!(self, Error::BufferAllocationFailed { .. })
190    }
191
192    /// True when this is a shader compilation error.
193    pub fn is_shader_error(&self) -> bool {
194        matches!(self, Error::ShaderCompilationError { .. })
195    }
196
197    /// True when this is an out-of-bounds grid error.
198    pub fn is_grid_error(&self) -> bool {
199        matches!(self, Error::GridIndexOutOfBounds { .. })
200    }
201
202    /// True when this is a kernel argument mismatch error.
203    pub fn is_arg_mismatch(&self) -> bool {
204        matches!(self, Error::KernelArgCountMismatch { .. })
205    }
206
207    /// True when this is an unsupported feature error.
208    pub fn is_unsupported(&self) -> bool {
209        matches!(self, Error::UnsupportedFeature { .. })
210    }
211
212    /// Wrap this error in a [`PipelineStageError`].
213    pub fn in_stage(self, stage: impl Into<String>) -> PipelineStageError {
214        PipelineStageError {
215            stage: stage.into(),
216            source: Box::new(self),
217        }
218    }
219
220    /// Annotate with fatal severity.
221    pub fn fatal(self, kernel: Option<&str>) -> AnnotatedError {
222        AnnotatedError::fatal(self, kernel)
223    }
224
225    /// Annotate with warning severity.
226    pub fn warning(self, kernel: Option<&str>) -> AnnotatedError {
227        AnnotatedError::warning(self, kernel)
228    }
229
230    /// Convert into a `Result::Err`.
231    pub fn into_err<T>(self) -> Result<T> {
232        Err(self)
233    }
234}
235
236// ── Convenience constructors ─────────────────────────────────────────────────
237
238/// Build a [`Error::BufferAllocationFailed`] error.
239pub fn alloc_err(requested_bytes: usize, available_bytes: usize) -> Error {
240    Error::BufferAllocationFailed {
241        requested_bytes,
242        available_bytes,
243    }
244}
245
246/// Build a [`Error::KernelArgCountMismatch`] error.
247pub fn arg_mismatch_err(kernel: impl Into<String>, expected: usize, got: usize) -> Error {
248    Error::KernelArgCountMismatch {
249        kernel: kernel.into(),
250        expected,
251        got,
252    }
253}
254
255/// Build a [`Error::GridIndexOutOfBounds`] error.
256#[allow(clippy::too_many_arguments)]
257pub fn grid_oob_err(i: usize, j: usize, k: usize, nx: usize, ny: usize, nz: usize) -> Error {
258    Error::GridIndexOutOfBounds {
259        i,
260        j,
261        k,
262        nx,
263        ny,
264        nz,
265    }
266}
267
268/// Build a [`Error::DispatchLimitExceeded`] error.
269pub fn dispatch_limit_err(dispatch_size: usize, limit: usize) -> Error {
270    Error::DispatchLimitExceeded {
271        dispatch_size,
272        limit,
273    }
274}
275
276/// Build a [`Error::ShaderCompilationError`].
277pub fn shader_err(shader: impl Into<String>, message: impl Into<String>) -> Error {
278    Error::ShaderCompilationError {
279        shader: shader.into(),
280        message: message.into(),
281    }
282}
283
284/// Build an [`Error::UnsupportedFeature`].
285pub fn unsupported_err(feature: impl Into<String>) -> Error {
286    Error::UnsupportedFeature {
287        feature: feature.into(),
288    }
289}
290
291// ── Error collection ─────────────────────────────────────────────────────────
292
293/// Collect multiple errors from a batch dispatch.  Returns `Ok(())` if the
294/// vec is empty, or `Err` containing the first error otherwise.
295pub fn collect_errors(errors: Vec<Error>) -> Result<()> {
296    errors.into_iter().next().map_or(Ok(()), Err)
297}
298
299/// Check a boolean condition; return `Err(Error::General(msg))` if false.
300pub fn check(condition: bool, msg: impl std::fmt::Display) -> Result<()> {
301    if condition {
302        Ok(())
303    } else {
304        Err(Error::general(msg))
305    }
306}
307
/// Unit tests for the error types, predicates and convenience constructors
/// defined in this file.
#[cfg(test)]
mod error_tests {
    use super::*;

    #[test]
    fn test_general_error_message() {
        let e = Error::general("something went wrong");
        assert_eq!(e.to_string(), "something went wrong");
    }

    #[test]
    fn test_buffer_allocation_failed_message() {
        let e = Error::BufferAllocationFailed {
            requested_bytes: 1024,
            available_bytes: 512,
        };
        let msg = e.to_string();
        assert!(msg.contains("1024"), "should mention requested bytes");
        assert!(msg.contains("512"), "should mention available bytes");
        assert!(e.is_allocation_error());
    }

    #[test]
    fn test_invalid_buffer_handle() {
        let e = Error::InvalidBufferHandle(42);
        assert!(e.to_string().contains("42"));
    }

    #[test]
    fn test_shader_compilation_error() {
        let e = Error::ShaderCompilationError {
            shader: "sph_density".to_string(),
            message: "undefined symbol".to_string(),
        };
        let msg = e.to_string();
        assert!(msg.contains("sph_density"));
        assert!(msg.contains("undefined symbol"));
        assert!(e.is_shader_error());
    }

    #[test]
    fn test_dispatch_limit_exceeded() {
        let e = Error::DispatchLimitExceeded {
            dispatch_size: 100_000,
            limit: 65535,
        };
        let msg = e.to_string();
        assert!(msg.contains("100000"));
        assert!(msg.contains("65535"));
    }

    #[test]
    fn test_grid_index_out_of_bounds() {
        let e = Error::GridIndexOutOfBounds {
            i: 10,
            j: 5,
            k: 3,
            nx: 8,
            ny: 8,
            nz: 8,
        };
        // Pin the whole message: substring checks for "10" and "8" would
        // still pass if indices and dimensions were swapped in the template.
        assert_eq!(
            e.to_string(),
            "grid index (10, 5, 3) out of bounds for grid (8, 8, 8)"
        );
    }

    #[test]
    fn test_is_not_shader_error() {
        let e = Error::general("not a shader error");
        assert!(!e.is_shader_error());
    }

    #[test]
    fn test_unsupported_feature() {
        let e = Error::UnsupportedFeature {
            feature: "ray_tracing".to_string(),
        };
        assert!(e.to_string().contains("ray_tracing"));
    }

    // ── New error variant / helper tests ─────────────────────────────────

    #[test]
    fn test_is_grid_error() {
        let e = grid_oob_err(1, 2, 3, 4, 5, 6);
        assert!(e.is_grid_error());
        assert!(!e.is_allocation_error());
    }

    #[test]
    fn test_is_arg_mismatch() {
        let e = arg_mismatch_err("test_kernel", 3, 2);
        assert!(e.is_arg_mismatch());
        assert!(!e.is_shader_error());
    }

    #[test]
    fn test_is_unsupported() {
        let e = unsupported_err("ray_tracing");
        assert!(e.is_unsupported());
    }

    #[test]
    fn test_in_stage_wraps_error() {
        let e = Error::general("boom");
        let wrapped = e.in_stage("sph_density");
        assert!(wrapped.to_string().contains("sph_density"));
        assert!(wrapped.to_string().contains("boom"));
    }

    #[test]
    fn test_alloc_err_convenience() {
        let e = alloc_err(512, 256);
        assert!(e.is_allocation_error());
        assert!(e.to_string().contains("512"));
    }

    #[test]
    fn test_dispatch_limit_err_convenience() {
        let e = dispatch_limit_err(99999, 65535);
        assert!(e.to_string().contains("99999"));
    }

    #[test]
    fn test_shader_err_convenience() {
        let e = shader_err("my_shader", "syntax error");
        assert!(e.is_shader_error());
        assert!(e.to_string().contains("syntax error"));
    }

    #[test]
    fn test_into_err() {
        let result: Result<i32> = Error::general("nope").into_err();
        assert!(result.is_err());
    }

    #[test]
    fn test_collect_errors_empty() {
        assert!(collect_errors(vec![]).is_ok());
    }

    #[test]
    fn test_collect_errors_nonempty() {
        let errs = vec![Error::general("first"), Error::general("second")];
        let result = collect_errors(errs);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("first"));
    }

    #[test]
    fn test_check_passes() {
        assert!(check(true, "should not fail").is_ok());
    }

    #[test]
    fn test_check_fails() {
        let r = check(false, "condition violated");
        assert!(r.is_err());
        assert!(r.unwrap_err().to_string().contains("condition violated"));
    }

    #[test]
    fn test_annotated_error_fatal_display() {
        let e = Error::general("crash");
        let ann = e.fatal(Some("sph_kernel"));
        let s = ann.to_string();
        assert!(s.contains("FATAL"));
        assert!(s.contains("sph_kernel"));
        assert!(s.contains("crash"));
    }

    #[test]
    fn test_annotated_error_warning_no_kernel() {
        let e = Error::general("degraded");
        let ann = e.warning(None);
        let s = ann.to_string();
        assert!(s.contains("WARNING"));
        assert!(s.contains("degraded"));
    }

    #[test]
    fn test_error_severity_ordering() {
        assert!(ErrorSeverity::Info < ErrorSeverity::Warning);
        assert!(ErrorSeverity::Warning < ErrorSeverity::Fatal);
    }

    #[test]
    fn test_error_severity_display() {
        assert_eq!(ErrorSeverity::Info.to_string(), "INFO");
        assert_eq!(ErrorSeverity::Warning.to_string(), "WARNING");
        assert_eq!(ErrorSeverity::Fatal.to_string(), "FATAL");
    }
}