Skip to main content

etensor_core/
errors.rs

//! Unified error handling and boundary enforcement for ETensor.

3/// The standard result type utilized across the entire ETensor engine.
4///
5/// By wrapping all backend and autograd operations in this Result, we guarantee
6/// zero implicit panics, allowing graceful error handoffs to Python bindings.
7pub type EtensorResult<T> = Result<T, EtensorError>;
8
/// Every failure state that can arise within the ETensor execution pipeline.
///
/// Each variant renders through [`std::fmt::Display`] as a message prefixed with
/// the variant's name (e.g. `"BackendError: ..."`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EtensorError {
    /// Raised when an operation receives tensors with fundamentally incompatible dimensional geometry.
    ShapeMismatch {
        /// The shape the operation required.
        expected: Vec<usize>,
        /// The shape that was actually supplied.
        got: Vec<usize>,
    },

    /// Raised when mathematical operations are attempted across isolated memory spaces
    /// (e.g. adding a CPU buffer directly to a CudaNative buffer).
    DeviceMismatch {
        /// The device the operation expected its operands to live on.
        expected: String,
        /// The device the offending operand was actually found on.
        got: String,
    },

    /// Raised when a kernel does not support the provided precision type.
    DTypeMismatch {
        /// The dtype the kernel expected.
        expected: String,
        /// The dtype that was actually supplied.
        got: String,
    },

    /// Raised when a failure occurs during Tape recording, topological sorting, or gradient accumulation.
    AutogradError(String),

    /// Raised when a hardware backend (CUDA driver, Torch C++, Rayon) fails an internal execution.
    BackendError(String),

    /// Raised during serialization or deserialization of model weights (e.g. Safetensors).
    IoError(String),

    /// A generic fallback for unclassified engine structural violations.
    InternalError(String),
}

impl std::fmt::Display for EtensorError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Shapes are printed with `{:?}` so they render as `[2, 3]`.
            Self::ShapeMismatch { expected, got } => write!(
                f,
                "ShapeMismatch: Operation expected shape {:?}, but got {:?}",
                expected, got
            ),
            Self::DeviceMismatch { expected, got } => write!(
                f,
                "DeviceMismatch: Operation expected tensors on '{}', but found tensor on '{}'",
                expected, got
            ),
            Self::DTypeMismatch { expected, got } => write!(
                f,
                "DTypeMismatch: Operation expected dtype '{}', but found '{}'",
                expected, got
            ),
            Self::AutogradError(msg) => write!(f, "AutogradError: {}", msg),
            Self::BackendError(msg) => write!(f, "BackendError: {}", msg),
            Self::IoError(msg) => write!(f, "IoError: {}", msg),
            Self::InternalError(msg) => write!(f, "InternalError: {}", msg),
        }
    }
}

// Implementing `Error` (with the default `source()`, which returns `None`) lets an
// `EtensorError` be used anywhere a `dyn std::error::Error` is expected. That
// includes `?` inside functions returning `Box<dyn std::error::Error>`: `?` itself
// only needs a `From` conversion, and this impl is what makes the standard library's
// blanket `From<E: Error> for Box<dyn Error>` apply.
impl std::error::Error for EtensorError {}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_shape_mismatch_formatting() {
        let err = EtensorError::ShapeMismatch {
            expected: vec![2, 3],
            got: vec![3, 2],
        };
        assert_eq!(
            err.to_string(),
            "ShapeMismatch: Operation expected shape [2, 3], but got [3, 2]"
        );
    }

    #[test]
    fn test_device_mismatch_formatting() {
        let err = EtensorError::DeviceMismatch {
            expected: "cuda:0".to_string(),
            got: "cpu".to_string(),
        };
        assert_eq!(
            err.to_string(),
            "DeviceMismatch: Operation expected tensors on 'cuda:0', but found tensor on 'cpu'"
        );
    }

    #[test]
    fn test_dtype_mismatch_formatting() {
        let err = EtensorError::DTypeMismatch {
            expected: "f32".to_string(),
            got: "f16".to_string(),
        };
        assert_eq!(
            err.to_string(),
            "DTypeMismatch: Operation expected dtype 'f32', but found 'f16'"
        );
    }

    #[test]
    fn test_message_variant_formatting() {
        // Every string-payload variant is rendered as "<VariantName>: <message>".
        let cases = vec![
            (
                EtensorError::AutogradError("cycle in tape".to_string()),
                "AutogradError: cycle in tape",
            ),
            (
                EtensorError::BackendError("kernel launch failed".to_string()),
                "BackendError: kernel launch failed",
            ),
            (
                EtensorError::IoError("truncated header".to_string()),
                "IoError: truncated header",
            ),
            (
                EtensorError::InternalError("broken invariant".to_string()),
                "InternalError: broken invariant",
            ),
        ];
        for (err, expected) in &cases {
            assert_eq!(err.to_string(), *expected);
        }
    }

    #[test]
    fn test_result_propagation() {
        // Maps the boolean device flag to the name reported in errors.
        fn device_name(is_cpu: bool) -> &'static str {
            if is_cpu {
                "cpu"
            } else {
                "cuda"
            }
        }

        // A dummy op that requires both operands on the same device, reporting the
        // first operand's device as `expected` and the second's as `got`.
        fn mock_add(a_is_cpu: bool, b_is_cpu: bool) -> EtensorResult<f32> {
            if a_is_cpu != b_is_cpu {
                return Err(EtensorError::DeviceMismatch {
                    expected: device_name(a_is_cpu).to_string(),
                    got: device_name(b_is_cpu).to_string(),
                });
            }
            Ok(42.0)
        }

        // Exercises `?`: an error from `mock_add` must short-circuit unchanged.
        fn doubled_sum(a_is_cpu: bool, b_is_cpu: bool) -> EtensorResult<f32> {
            let sum = mock_add(a_is_cpu, b_is_cpu)?;
            Ok(sum * 2.0)
        }

        assert_eq!(mock_add(true, true), Ok(42.0));
        assert_eq!(doubled_sum(false, false), Ok(84.0));
        assert_eq!(
            doubled_sum(true, false),
            Err(EtensorError::DeviceMismatch {
                expected: "cpu".to_string(),
                got: "cuda".to_string(),
            })
        );
        assert_eq!(
            doubled_sum(false, true),
            Err(EtensorError::DeviceMismatch {
                expected: "cuda".to_string(),
                got: "cpu".to_string(),
            })
        );
    }

    #[test]
    fn test_boxed_dyn_error_conversion() {
        // `?` converts an `EtensorError` into `Box<dyn Error>` via the std blanket `From`.
        fn load() -> Result<(), Box<dyn std::error::Error>> {
            let result: EtensorResult<()> =
                Err(EtensorError::IoError("missing tensor".to_string()));
            result?;
            Ok(())
        }

        assert_eq!(
            load().unwrap_err().to_string(),
            "IoError: missing tensor"
        );
    }
}