tenrso_kernels/
error.rs

1//! Error types for tensor kernel operations
2//!
3//! This module provides structured error types for kernel operations,
4//! making error handling more robust and informative.
5
6use std::fmt;
7
/// Error type for tensor kernel operations.
///
/// Each variant carries structured data (operation name, offending sizes,
/// free-form detail) rather than a pre-rendered string, so callers can match
/// on the failure kind and still get a readable message via `Display`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum KernelError {
    /// Dimension mismatch between operands.
    DimensionMismatch {
        /// Name of the operation that detected the mismatch.
        operation: String,
        /// Dimensions the operation expected.
        expected: Vec<usize>,
        /// Dimensions that were actually supplied.
        actual: Vec<usize>,
        /// Free-form extra detail about the mismatch.
        context: String,
    },

    /// Invalid mode/axis specification.
    InvalidMode {
        /// The mode index that was requested.
        mode: usize,
        /// Exclusive upper bound for valid mode indices.
        max_mode: usize,
        /// Free-form extra detail about the invalid mode.
        context: String,
    },

    /// Rank mismatch (e.g., different CP ranks in factor matrices).
    RankMismatch {
        /// Name of the operation that detected the mismatch.
        operation: String,
        /// Rank the operation expected.
        expected_rank: usize,
        /// Rank that was actually found.
        actual_rank: usize,
        /// Index of the factor whose rank did not match.
        factor_index: usize,
    },

    /// Empty input not allowed.
    EmptyInput {
        /// Name of the operation that rejected the input.
        operation: String,
        /// Name of the parameter that was empty.
        parameter: String,
    },

    /// Invalid tile/block size.
    InvalidTileSize {
        /// Name of the operation that rejected the tile size.
        operation: String,
        /// The rejected tile size.
        tile_size: usize,
        /// Why the tile size was rejected.
        reason: String,
    },

    /// Shape incompatibility.
    IncompatibleShapes {
        /// Name of the operation that detected the incompatibility.
        operation: String,
        /// Shape of the first operand.
        shape_a: Vec<usize>,
        /// Shape of the second operand.
        shape_b: Vec<usize>,
        /// Why the shapes are incompatible.
        reason: String,
    },

    /// Generic operation error with context.
    OperationError {
        /// Name of the failing operation.
        operation: String,
        /// Description of the failure.
        message: String,
    },
}
58
59impl fmt::Display for KernelError {
60    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61        match self {
62            KernelError::DimensionMismatch {
63                operation,
64                expected,
65                actual,
66                context,
67            } => write!(
68                f,
69                "{}: dimension mismatch - expected {:?}, got {:?}. {}",
70                operation, expected, actual, context
71            ),
72
73            KernelError::InvalidMode {
74                mode,
75                max_mode,
76                context,
77            } => write!(
78                f,
79                "Invalid mode {}: must be < {}. {}",
80                mode, max_mode, context
81            ),
82
83            KernelError::RankMismatch {
84                operation,
85                expected_rank,
86                actual_rank,
87                factor_index,
88            } => write!(
89                f,
90                "{}: rank mismatch at factor {}: expected rank {}, got {}",
91                operation, factor_index, expected_rank, actual_rank
92            ),
93
94            KernelError::EmptyInput {
95                operation,
96                parameter,
97            } => write!(
98                f,
99                "{}: empty input not allowed for parameter '{}'",
100                operation, parameter
101            ),
102
103            KernelError::InvalidTileSize {
104                operation,
105                tile_size,
106                reason,
107            } => write!(
108                f,
109                "{}: invalid tile size {}: {}",
110                operation, tile_size, reason
111            ),
112
113            KernelError::IncompatibleShapes {
114                operation,
115                shape_a,
116                shape_b,
117                reason,
118            } => write!(
119                f,
120                "{}: incompatible shapes {:?} and {:?}: {}",
121                operation, shape_a, shape_b, reason
122            ),
123
124            KernelError::OperationError { operation, message } => {
125                write!(f, "{}: {}", operation, message)
126            }
127        }
128    }
129}
130
131impl std::error::Error for KernelError {}
132
133/// Result type for kernel operations
134pub type KernelResult<T> = Result<T, KernelError>;
135
136impl KernelError {
137    /// Create a dimension mismatch error
138    pub fn dimension_mismatch(
139        operation: impl Into<String>,
140        expected: Vec<usize>,
141        actual: Vec<usize>,
142        context: impl Into<String>,
143    ) -> Self {
144        KernelError::DimensionMismatch {
145            operation: operation.into(),
146            expected,
147            actual,
148            context: context.into(),
149        }
150    }
151
152    /// Create an invalid mode error
153    pub fn invalid_mode(mode: usize, max_mode: usize, context: impl Into<String>) -> Self {
154        KernelError::InvalidMode {
155            mode,
156            max_mode,
157            context: context.into(),
158        }
159    }
160
161    /// Create a rank mismatch error
162    pub fn rank_mismatch(
163        operation: impl Into<String>,
164        expected_rank: usize,
165        actual_rank: usize,
166        factor_index: usize,
167    ) -> Self {
168        KernelError::RankMismatch {
169            operation: operation.into(),
170            expected_rank,
171            actual_rank,
172            factor_index,
173        }
174    }
175
176    /// Create an empty input error
177    pub fn empty_input(operation: impl Into<String>, parameter: impl Into<String>) -> Self {
178        KernelError::EmptyInput {
179            operation: operation.into(),
180            parameter: parameter.into(),
181        }
182    }
183
184    /// Create an invalid tile size error
185    pub fn invalid_tile_size(
186        operation: impl Into<String>,
187        tile_size: usize,
188        reason: impl Into<String>,
189    ) -> Self {
190        KernelError::InvalidTileSize {
191            operation: operation.into(),
192            tile_size,
193            reason: reason.into(),
194        }
195    }
196
197    /// Create an incompatible shapes error
198    pub fn incompatible_shapes(
199        operation: impl Into<String>,
200        shape_a: Vec<usize>,
201        shape_b: Vec<usize>,
202        reason: impl Into<String>,
203    ) -> Self {
204        KernelError::IncompatibleShapes {
205            operation: operation.into(),
206            shape_a,
207            shape_b,
208            reason: reason.into(),
209        }
210    }
211
212    /// Create a generic operation error
213    pub fn operation_error(operation: impl Into<String>, message: impl Into<String>) -> Self {
214        KernelError::OperationError {
215            operation: operation.into(),
216            message: message.into(),
217        }
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dimension_mismatch_display() {
        let err = KernelError::dimension_mismatch(
            "khatri_rao",
            vec![10, 5],
            vec![10, 3],
            "Number of columns must match",
        );

        let msg = format!("{}", err);
        assert!(msg.contains("khatri_rao"));
        assert!(msg.contains("dimension mismatch"));
        assert!(msg.contains("[10, 5]"));
        assert!(msg.contains("[10, 3]"));
    }

    #[test]
    fn test_invalid_mode_display() {
        let err = KernelError::invalid_mode(3, 3, "Tensor has only 3 modes");

        let msg = format!("{}", err);
        assert!(msg.contains("Invalid mode 3"));
        assert!(msg.contains("must be < 3"));
    }

    #[test]
    fn test_rank_mismatch_display() {
        let err = KernelError::rank_mismatch("mttkrp", 5, 3, 2);

        let msg = format!("{}", err);
        assert!(msg.contains("mttkrp"));
        assert!(msg.contains("factor 2"));
        assert!(msg.contains("expected rank 5"));
        assert!(msg.contains("got 3"));
    }

    #[test]
    fn test_empty_input_display() {
        let err = KernelError::empty_input("outer_product", "vectors");

        let msg = format!("{}", err);
        assert!(msg.contains("outer_product"));
        assert!(msg.contains("empty input"));
        assert!(msg.contains("vectors"));
    }

    #[test]
    fn test_invalid_tile_size_display() {
        let err = KernelError::invalid_tile_size("mttkrp_blocked", 0, "must be positive");

        let msg = format!("{}", err);
        assert!(msg.contains("mttkrp_blocked"));
        assert!(msg.contains("invalid tile size 0"));
        assert!(msg.contains("must be positive"));
    }

    #[test]
    fn test_incompatible_shapes_display() {
        let err = KernelError::incompatible_shapes(
            "hadamard",
            vec![2, 3],
            vec![2, 4],
            "Element-wise multiplication requires same shape",
        );

        let msg = format!("{}", err);
        assert!(msg.contains("hadamard"));
        assert!(msg.contains("[2, 3]"));
        assert!(msg.contains("[2, 4]"));
        assert!(msg.contains("Element-wise multiplication"));
    }

    #[test]
    fn test_operation_error_display() {
        let err = KernelError::operation_error("tucker", "decomposition did not converge");

        assert_eq!(format!("{}", err), "tucker: decomposition did not converge");
    }

    #[test]
    fn test_constructors_match_variants() {
        assert_eq!(
            KernelError::empty_input("outer_product", "vectors"),
            KernelError::EmptyInput {
                operation: "outer_product".to_string(),
                parameter: "vectors".to_string(),
            }
        );
        assert_eq!(
            KernelError::rank_mismatch("mttkrp", 5, 3, 2),
            KernelError::RankMismatch {
                operation: "mttkrp".to_string(),
                expected_rank: 5,
                actual_rank: 3,
                factor_index: 2,
            }
        );
    }

    #[test]
    fn test_usable_as_std_error() {
        use std::error::Error as _;

        let err: Box<dyn std::error::Error> =
            Box::new(KernelError::invalid_mode(4, 3, "3-way tensor"));
        assert!(err.to_string().contains("Invalid mode 4"));
        assert!(err.source().is_none());
    }

    #[test]
    fn test_kernel_result_propagates_with_question_mark() {
        fn check_tile(tile: usize) -> KernelResult<usize> {
            if tile == 0 {
                return Err(KernelError::invalid_tile_size(
                    "blocked",
                    tile,
                    "must be positive",
                ));
            }
            Ok(tile)
        }

        fn doubled(tile: usize) -> KernelResult<usize> {
            Ok(check_tile(tile)? * 2)
        }

        assert_eq!(doubled(4), Ok(8));
        assert_eq!(
            doubled(0),
            Err(KernelError::invalid_tile_size("blocked", 0, "must be positive"))
        );
    }
}