Skip to main content

cubek_reduce/
error.rs

1use cubecl::{ir::StorageType, server::LaunchError};
2use thiserror::Error;
3
4#[derive(Error, Debug, Clone)]
5/// This error should be caught and properly handled.
6pub enum ReduceError {
7    /// Indicate that the hardware / API doesn't support SIMT plane instructions.
8    #[error(
9        "Trying to launch a kernel using plane instructions, but there are not supported by the hardware."
10    )]
11    PlanesUnavailable,
12    /// When the cube count is bigger than the max supported.
13    #[error("The cube count is larger than the max supported.")]
14    CubeCountTooLarge,
15
16    /// A generic validation error
17    #[error("A generic validation error: {details}")]
18    Validation { details: &'static str },
19
20    /// Indicate that min_plane_dim != max_plane_dim, thus the exact plane_dim is not fixed.
21    #[error(
22        "Trying to launch a kernel using plane instructions, but the min and max plane dimensions are different."
23    )]
24    ImprecisePlaneDim,
25    /// Indicate the axis is too large.
26    #[error("The provided axis ({axis}) must be smaller than the input tensor rank ({rank}).")]
27    InvalidAxis { axis: usize, rank: usize },
28    /// Indicate that the shape of the input tensor is too small for the given input and axis.
29    #[error(
30        "The input reduce axis length (currently {axis_length:?}) should be at least k ({k:?})."
31    )]
32    ReduceAxisTooSmall { axis_length: usize, k: usize },
33    /// Indicate that the shape of the output tensor is invalid for the given input and axis.
34    #[error("The output shape (currently {output_shape:?}) should be {expected_shape:?}.")]
35    MismatchOutputShape {
36        expected_shape: Vec<usize>,
37        output_shape: Vec<usize>,
38    },
39    /// Indicate that we can't launch a shared sum because the atomic addition is not supported.
40    #[error("Atomic add not supported by the client for {0}")]
41    MissingAtomicAdd(StorageType),
42
43    /// An error happened during launch.
44    #[error("An error happened during launch\nCaused by:\n  {0}")]
45    Launch(LaunchError),
46}