axonml_jit/
error.rs

1//! JIT Error Types
2//!
3//! Error handling for JIT compilation operations.
4
5use std::fmt;
6
7/// Result type for JIT operations.
8pub type JitResult<T> = Result<T, JitError>;
9
/// Errors produced by JIT compilation operations.
///
/// Every payload is plain owned data (`String` or `Vec<usize>`), so the type
/// is cheap to clone and supports structural equality. That lets callers and
/// tests compare errors directly with `==` / `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JitError {
    /// Invalid graph structure; the payload is the descriptive message.
    InvalidGraph(String),
    /// Type mismatch in operations.
    TypeMismatch {
        /// Expected type.
        expected: String,
        /// Actual type.
        found: String,
    },
    /// Shape mismatch in operations.
    ShapeMismatch {
        /// Expected shape.
        expected: Vec<usize>,
        /// Actual shape.
        found: Vec<usize>,
    },
    /// Unsupported operation for JIT; the payload names the operation.
    UnsupportedOp(String),
    /// Code generation failed; the payload is the descriptive message.
    CodegenError(String),
    /// Runtime execution error; the payload is the descriptive message.
    RuntimeError(String),
    /// Input not found; the payload is the input's name.
    InputNotFound(String),
    /// Output not found; the payload is the output's name.
    OutputNotFound(String),
    /// Compilation failed; the payload is the descriptive message.
    CompilationFailed(String),
}
42
43impl fmt::Display for JitError {
44    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45        match self {
46            Self::InvalidGraph(msg) => write!(f, "Invalid graph: {msg}"),
47            Self::TypeMismatch { expected, found } => {
48                write!(f, "Type mismatch: expected {expected}, found {found}")
49            }
50            Self::ShapeMismatch { expected, found } => {
51                write!(f, "Shape mismatch: expected {expected:?}, found {found:?}")
52            }
53            Self::UnsupportedOp(op) => write!(f, "Unsupported operation: {op}"),
54            Self::CodegenError(msg) => write!(f, "Code generation error: {msg}"),
55            Self::RuntimeError(msg) => write!(f, "Runtime error: {msg}"),
56            Self::InputNotFound(name) => write!(f, "Input not found: {name}"),
57            Self::OutputNotFound(name) => write!(f, "Output not found: {name}"),
58            Self::CompilationFailed(msg) => write!(f, "Compilation failed: {msg}"),
59        }
60    }
61}
62
63impl std::error::Error for JitError {}
64
65impl From<String> for JitError {
66    fn from(msg: String) -> Self {
67        Self::RuntimeError(msg)
68    }
69}
70
#[cfg(test)]
mod tests {
    use super::*;

    /// The type-mismatch message includes both the expected and found type.
    #[test]
    fn test_error_display() {
        let err = JitError::TypeMismatch {
            expected: "f32".to_string(),
            found: "i32".to_string(),
        };
        assert_eq!(err.to_string(), "Type mismatch: expected f32, found i32");
    }

    /// Shapes are rendered with their `Debug` form (`[a, b]`).
    #[test]
    fn test_shape_mismatch() {
        let err = JitError::ShapeMismatch {
            expected: vec![2, 3],
            found: vec![3, 2],
        };
        assert_eq!(
            err.to_string(),
            "Shape mismatch: expected [2, 3], found [3, 2]"
        );
    }

    /// Every single-payload variant renders as `"<label>: <payload>"`.
    #[test]
    fn test_message_variants_display() {
        let cases = [
            (JitError::InvalidGraph("cycle".to_string()), "Invalid graph: cycle"),
            (JitError::UnsupportedOp("conv".to_string()), "Unsupported operation: conv"),
            (JitError::CodegenError("bad ir".to_string()), "Code generation error: bad ir"),
            (JitError::RuntimeError("boom".to_string()), "Runtime error: boom"),
            (JitError::InputNotFound("x".to_string()), "Input not found: x"),
            (JitError::OutputNotFound("y".to_string()), "Output not found: y"),
            (JitError::CompilationFailed("oops".to_string()), "Compilation failed: oops"),
        ];
        for (err, expected) in cases {
            assert_eq!(err.to_string(), expected);
        }
    }

    /// A bare `String` converts into the `RuntimeError` variant unchanged.
    #[test]
    fn test_from_string() {
        let err = JitError::from("boom".to_string());
        assert!(matches!(err, JitError::RuntimeError(ref msg) if msg == "boom"));
    }

    /// The error is usable as a trait object and reports no underlying source.
    #[test]
    fn test_is_std_error() {
        let err: Box<dyn std::error::Error> = Box::new(JitError::InvalidGraph("g".to_string()));
        assert!(err.source().is_none());
        assert_eq!(err.to_string(), "Invalid graph: g");
    }
}