// god_graph/tensor/error.rs

//! Tensor error types.
//!
//! Defines the various errors that can occur during tensor operations.

use core::fmt;

/// Errors produced by tensor operations.
///
/// Each variant carries the context needed to render a self-contained message
/// through its [`fmt::Display`] implementation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TensorError {
    /// A tensor's shape did not match the shape an operation required.
    ShapeMismatch {
        /// The expected shape.
        expected: Vec<usize>,
        /// The shape that was actually received.
        got: Vec<usize>,
    },

    /// An index fell outside the bounds of a dimension.
    IndexOutOfBounds {
        /// The out-of-bounds index.
        index: usize,
        /// The dimension that was being indexed.
        dim: usize,
        /// The size of that dimension.
        size: usize,
    },

    /// The number of dimensions did not match what an operation required.
    DimensionMismatch {
        /// The expected number of dimensions.
        expected: usize,
        /// The number of dimensions actually received.
        got: usize,
    },

    /// A data type is not supported by an operation.
    UnsupportedDType {
        /// The unsupported data type.
        dtype: String,
        /// The name of the operation.
        operation: String,
    },

    /// A device is not supported.
    UnsupportedDevice {
        /// The unsupported device.
        device: String,
    },

    /// Out-of-memory error.
    OutOfMemory {
        /// The number of bytes requested.
        requested_bytes: usize,
    },

    /// An error reported by a BLAS routine.
    BlasError {
        /// The BLAS error code.
        code: i32,
        /// A description of the error.
        description: String,
    },

    /// Converting between sparse storage formats failed.
    SparseFormatError {
        /// The source format.
        from: String,
        /// The target format.
        to: String,
        /// A description of the error.
        description: String,
    },

    /// Two shapes could not be broadcast together.
    BroadcastError {
        /// The first shape.
        shape1: Vec<usize>,
        /// The second shape.
        shape2: Vec<usize>,
    },

    /// A slicing operation failed.
    SliceError {
        /// A description of the error.
        description: String,
    },

    /// Transferring data between devices failed.
    DeviceTransferError {
        /// The source device.
        from: String,
        /// The target device.
        to: String,
        /// A description of the error.
        description: String,
    },

    /// Memory allocation failed.
    AllocationError {
        /// A description of the error.
        message: String,
    },

    /// A matrix operation failed.
    MatrixError {
        /// A description of the error.
        message: String,
    },
}

/// Renders a one-line, human-readable message for each variant.
///
/// Shape vectors are printed with their `Debug` form (e.g. `[2, 3]`), and
/// `DimensionMismatch` counts are suffixed with `D` (e.g. `2D`).
impl fmt::Display for TensorError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TensorError::ShapeMismatch { expected, got } => {
                write!(f, "Shape mismatch: expected {:?}, got {:?}", expected, got)
            }
            TensorError::IndexOutOfBounds { index, dim, size } => {
                write!(
                    f,
                    "Index {} out of bounds for dimension {} (size {})",
                    index, dim, size
                )
            }
            TensorError::DimensionMismatch { expected, got } => {
                write!(
                    f,
                    "Dimension mismatch: expected {}D, got {}D",
                    expected, got
                )
            }
            TensorError::UnsupportedDType { dtype, operation } => {
                write!(f, "Unsupported dtype {} for operation {}", dtype, operation)
            }
            TensorError::UnsupportedDevice { device } => {
                write!(f, "Unsupported device: {}", device)
            }
            TensorError::OutOfMemory { requested_bytes } => {
                write!(f, "Out of memory: requested {} bytes", requested_bytes)
            }
            TensorError::BlasError { code, description } => {
                write!(f, "BLAS error (code {}): {}", code, description)
            }
            TensorError::SparseFormatError {
                from,
                to,
                description,
            } => {
                write!(
                    f,
                    "Sparse format error converting {} to {}: {}",
                    from, to, description
                )
            }
            TensorError::BroadcastError { shape1, shape2 } => {
                write!(f, "Cannot broadcast shapes {:?} and {:?}", shape1, shape2)
            }
            TensorError::SliceError { description } => {
                write!(f, "Slice error: {}", description)
            }
            TensorError::DeviceTransferError {
                from,
                to,
                description,
            } => {
                write!(
                    f,
                    "Device transfer error from {} to {}: {}",
                    from, to, description
                )
            }
            TensorError::AllocationError { message } => {
                write!(f, "Allocation error: {}", message)
            }
            TensorError::MatrixError { message } => {
                write!(f, "Matrix error: {}", message)
            }
        }
    }
}

/// Makes [`TensorError`] usable as a standard error (e.g. `&dyn std::error::Error`).
///
/// Only compiled when the `std` feature is enabled. No variant wraps another
/// error, so the trait's default methods (e.g. `source()` returning `None`) apply.
#[cfg(feature = "std")]
impl std::error::Error for TensorError {}

184/// Tensor 操作结果类型别名
185pub type TensorResult<T> = Result<T, TensorError>;
186
#[cfg(test)]
mod tests {
    use super::*;

    /// The Display output of the most common variants is pinned verbatim.
    #[test]
    fn test_error_display() {
        let err = TensorError::ShapeMismatch {
            expected: vec![2, 3],
            got: vec![3, 2],
        };
        assert_eq!(
            format!("{}", err),
            "Shape mismatch: expected [2, 3], got [3, 2]"
        );

        let err = TensorError::IndexOutOfBounds {
            index: 5,
            dim: 0,
            size: 3,
        };
        assert_eq!(
            format!("{}", err),
            "Index 5 out of bounds for dimension 0 (size 3)"
        );
    }

    /// Pins the Display output of the remaining variants so message changes are deliberate.
    #[test]
    fn test_error_display_remaining_variants() {
        let cases = vec![
            (
                TensorError::DimensionMismatch {
                    expected: 2,
                    got: 3,
                },
                "Dimension mismatch: expected 2D, got 3D",
            ),
            (
                TensorError::UnsupportedDType {
                    dtype: "f16".to_string(),
                    operation: "matmul".to_string(),
                },
                "Unsupported dtype f16 for operation matmul",
            ),
            (
                TensorError::UnsupportedDevice {
                    device: "cuda:0".to_string(),
                },
                "Unsupported device: cuda:0",
            ),
            (
                TensorError::OutOfMemory {
                    requested_bytes: 1024,
                },
                "Out of memory: requested 1024 bytes",
            ),
            (
                TensorError::BlasError {
                    code: -1,
                    description: "bad arg".to_string(),
                },
                "BLAS error (code -1): bad arg",
            ),
            (
                TensorError::SparseFormatError {
                    from: "CSR".to_string(),
                    to: "COO".to_string(),
                    description: "bad".to_string(),
                },
                "Sparse format error converting CSR to COO: bad",
            ),
            (
                TensorError::BroadcastError {
                    shape1: vec![2, 3],
                    shape2: vec![4],
                },
                "Cannot broadcast shapes [2, 3] and [4]",
            ),
            (
                TensorError::SliceError {
                    description: "step is zero".to_string(),
                },
                "Slice error: step is zero",
            ),
            (
                TensorError::DeviceTransferError {
                    from: "cpu".to_string(),
                    to: "cuda".to_string(),
                    description: "busy".to_string(),
                },
                "Device transfer error from cpu to cuda: busy",
            ),
            (
                TensorError::AllocationError {
                    message: "oom".to_string(),
                },
                "Allocation error: oom",
            ),
            (
                TensorError::MatrixError {
                    message: "singular".to_string(),
                },
                "Matrix error: singular",
            ),
        ];
        for (err, expected) in cases {
            assert_eq!(format!("{}", err), expected);
        }
    }

    /// Structural equality: identical fields compare equal, and differing fields do not.
    #[test]
    fn test_error_equality() {
        let err1 = TensorError::ShapeMismatch {
            expected: vec![2, 3],
            got: vec![3, 2],
        };
        let err2 = TensorError::ShapeMismatch {
            expected: vec![2, 3],
            got: vec![3, 2],
        };
        assert_eq!(err1, err2);

        let err3 = TensorError::ShapeMismatch {
            expected: vec![2, 3],
            got: vec![2, 3],
        };
        assert_ne!(err1, err3);
    }

    /// A clone compares equal to its source, and different variants never compare equal.
    #[test]
    fn test_error_clone_and_variant_inequality() {
        let err = TensorError::SliceError {
            description: "x".to_string(),
        };
        assert_eq!(err.clone(), err);
        assert_ne!(
            err,
            TensorError::MatrixError {
                message: "x".to_string()
            }
        );
    }

    /// With `std` enabled, the error works as a trait object and reports no underlying source.
    #[cfg(feature = "std")]
    #[test]
    fn test_error_trait_object() {
        let owned = TensorError::SliceError {
            description: "bad".to_string(),
        };
        let err: &dyn std::error::Error = &owned;
        assert_eq!(format!("{}", err), "Slice error: bad");
        assert!(err.source().is_none());
    }
}