Skip to main content

god_gragh/tensor/
error.rs

//! Tensor error types.
//!
//! Defines the errors that tensor operations can report.
4
5use core::fmt;
6
/// Error produced by a fallible tensor operation.
///
/// Every variant is a struct-like variant whose named fields carry the
/// context describing what went wrong.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TensorError {
    /// A tensor's shape differs from the shape the operation required.
    ShapeMismatch {
        /// Shape the operation required.
        expected: Vec<usize>,
        /// Shape that was actually supplied.
        got: Vec<usize>,
    },

    /// An index fell outside the extent of one dimension.
    IndexOutOfBounds {
        /// The offending index.
        index: usize,
        /// The dimension the index was applied to.
        dim: usize,
        /// Extent of that dimension.
        size: usize,
    },

    /// The number of dimensions differs from what the operation required.
    DimensionMismatch {
        /// Number of dimensions the operation required.
        expected: usize,
        /// Number of dimensions actually supplied.
        got: usize,
    },

    /// The operation does not support the given data type.
    UnsupportedDType {
        /// Name of the unsupported data type.
        dtype: String,
        /// Name of the operation that rejected it.
        operation: String,
    },

    /// The requested device is not supported.
    UnsupportedDevice {
        /// Name of the unsupported device.
        device: String,
    },

    /// An allocation could not be satisfied.
    OutOfMemory {
        /// Number of bytes that were requested.
        requested_bytes: usize,
    },

    /// A BLAS routine reported a failure.
    BlasError {
        /// Status code reported by BLAS.
        code: i32,
        /// Human-readable explanation of the failure.
        description: String,
    },

    /// Converting between sparse storage formats failed.
    SparseFormatError {
        /// Format being converted from.
        from: String,
        /// Format being converted to.
        to: String,
        /// Human-readable explanation of the failure.
        description: String,
    },

    /// Two shapes could not be broadcast against each other.
    BroadcastError {
        /// Shape of the first operand.
        shape1: Vec<usize>,
        /// Shape of the second operand.
        shape2: Vec<usize>,
    },

    /// A slicing request was invalid.
    SliceError {
        /// Human-readable explanation of the failure.
        description: String,
    },

    /// Moving data from one device to another failed.
    DeviceTransferError {
        /// Device the data was being moved from.
        from: String,
        /// Device the data was being moved to.
        to: String,
        /// Human-readable explanation of the failure.
        description: String,
    },
}
98
99impl fmt::Display for TensorError {
100    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
101        match self {
102            TensorError::ShapeMismatch { expected, got } => {
103                write!(f, "Shape mismatch: expected {:?}, got {:?}", expected, got)
104            }
105            TensorError::IndexOutOfBounds { index, dim, size } => {
106                write!(f, "Index {} out of bounds for dimension {} (size {})", index, dim, size)
107            }
108            TensorError::DimensionMismatch { expected, got } => {
109                write!(f, "Dimension mismatch: expected {}D, got {}D", expected, got)
110            }
111            TensorError::UnsupportedDType { dtype, operation } => {
112                write!(f, "Unsupported dtype {} for operation {}", dtype, operation)
113            }
114            TensorError::UnsupportedDevice { device } => {
115                write!(f, "Unsupported device: {}", device)
116            }
117            TensorError::OutOfMemory { requested_bytes } => {
118                write!(f, "Out of memory: requested {} bytes", requested_bytes)
119            }
120            TensorError::BlasError { code, description } => {
121                write!(f, "BLAS error (code {}): {}", code, description)
122            }
123            TensorError::SparseFormatError { from, to, description } => {
124                write!(f, "Sparse format error converting {} to {}: {}", from, to, description)
125            }
126            TensorError::BroadcastError { shape1, shape2 } => {
127                write!(f, "Cannot broadcast shapes {:?} and {:?}", shape1, shape2)
128            }
129            TensorError::SliceError { description } => {
130                write!(f, "Slice error: {}", description)
131            }
132            TensorError::DeviceTransferError { from, to, description } => {
133                write!(f, "Device transfer error from {} to {}: {}", from, to, description)
134            }
135        }
136    }
137}
138
139#[cfg(feature = "std")]
140impl std::error::Error for TensorError {}
141
142/// Tensor 操作结果类型别名
143pub type TensorResult<T> = Result<T, TensorError>;
144
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the rendered message of every variant, so a wording change in any
    /// `Display` arm is caught (previously only two of eleven were checked).
    #[test]
    fn test_error_display() {
        let cases = [
            (
                TensorError::ShapeMismatch { expected: vec![2, 3], got: vec![3, 2] },
                "Shape mismatch: expected [2, 3], got [3, 2]",
            ),
            (
                TensorError::IndexOutOfBounds { index: 5, dim: 0, size: 3 },
                "Index 5 out of bounds for dimension 0 (size 3)",
            ),
            (
                TensorError::DimensionMismatch { expected: 2, got: 3 },
                "Dimension mismatch: expected 2D, got 3D",
            ),
            (
                TensorError::UnsupportedDType { dtype: "f16".into(), operation: "matmul".into() },
                "Unsupported dtype f16 for operation matmul",
            ),
            (
                TensorError::UnsupportedDevice { device: "tpu".into() },
                "Unsupported device: tpu",
            ),
            (
                TensorError::OutOfMemory { requested_bytes: 1024 },
                "Out of memory: requested 1024 bytes",
            ),
            (
                TensorError::BlasError { code: -3, description: "illegal value".into() },
                "BLAS error (code -3): illegal value",
            ),
            (
                TensorError::SparseFormatError {
                    from: "coo".into(),
                    to: "csr".into(),
                    description: "unsorted indices".into(),
                },
                "Sparse format error converting coo to csr: unsorted indices",
            ),
            (
                TensorError::BroadcastError { shape1: vec![2, 3], shape2: vec![4] },
                "Cannot broadcast shapes [2, 3] and [4]",
            ),
            (
                TensorError::SliceError { description: "start > end".into() },
                "Slice error: start > end",
            ),
            (
                TensorError::DeviceTransferError {
                    from: "cpu".into(),
                    to: "cuda:0".into(),
                    description: "no device".into(),
                },
                "Device transfer error from cpu to cuda:0: no device",
            ),
        ];

        for (err, expected) in cases.iter() {
            assert_eq!(format!("{}", err), *expected);
        }
    }

    /// Equality is structural: same variant and same field values.
    #[test]
    fn test_error_equality() {
        let err1 = TensorError::ShapeMismatch {
            expected: vec![2, 3],
            got: vec![3, 2],
        };
        let err2 = TensorError::ShapeMismatch {
            expected: vec![2, 3],
            got: vec![3, 2],
        };
        assert_eq!(err1, err2);

        let err3 = TensorError::ShapeMismatch {
            expected: vec![2, 3],
            got: vec![2, 3],
        };
        assert_ne!(err1, err3);

        // A clone compares equal to its source.
        assert_eq!(err1.clone(), err1);

        // Different variants never compare equal, even with "similar" payloads.
        let broadcast = TensorError::BroadcastError {
            shape1: vec![2, 3],
            shape2: vec![3, 2],
        };
        assert_ne!(err1, broadcast);
    }

    /// `TensorResult` carries either a value or a `TensorError`.
    #[test]
    fn test_tensor_result_alias() {
        fn first(xs: &[usize]) -> TensorResult<usize> {
            xs.first().copied().ok_or_else(|| TensorError::SliceError {
                description: "empty input".into(),
            })
        }

        assert_eq!(first(&[4, 5]), Ok(4));
        assert_eq!(
            first(&[]),
            Err(TensorError::SliceError { description: "empty input".into() })
        );
    }

    /// With the `std` feature, the error is usable as `&dyn std::error::Error`
    /// and keeps its `Display` message through the trait object.
    #[cfg(feature = "std")]
    #[test]
    fn test_implements_std_error() {
        fn render(err: &dyn std::error::Error) -> String {
            format!("{}", err)
        }

        let err = TensorError::OutOfMemory { requested_bytes: 64 };
        assert_eq!(render(&err), "Out of memory: requested 64 bytes");
    }
}