hpt_common/error/base.rs
1use thiserror::Error;
2
3use super::{
4 autograd::AutogradError, common::CommonError, device::DeviceError, kernel::KernelError,
5 memory::MemoryError, param::ParamError, random::RandomError, shape::ShapeError,
6};
7
8/// Base error type for all tensor operations
9#[derive(Debug, Error)]
10pub enum TensorError {
11 /// Shape-related errors such as dimension mismatch, broadcasting errors
12 #[error(transparent)]
13 Shape(#[from] ShapeError),
14
15 /// Device-related errors such as device not found, CUDA errors
16 #[error(transparent)]
17 Device(#[from] DeviceError),
18
19 /// Memory-related errors such as memory allocation failed, invalid memory layout
20 #[error(transparent)]
21 Memory(#[from] MemoryError),
22
23 /// Kernel-related errors such as kernel compilation failed, kernel execution failed
24 #[error(transparent)]
25 Kernel(#[from] KernelError),
26
27 /// Parameter-related errors such as invalid function arguments
28 #[error(transparent)]
29 Param(#[from] ParamError),
30
31 /// Autograd-related errors such as inplace computation is not allowed
32 #[error(transparent)]
33 Autograd(#[from] AutogradError),
34
35 /// Random distribution-related errors such as invalid distribution parameters
36 #[error(transparent)]
37 Random(#[from] RandomError),
38
39 /// Common errors such as lock failed
40 #[error(transparent)]
41 Common(#[from] CommonError),
42}