1use std::fmt;
6
7pub type JitResult<T> = Result<T, JitError>;
9
/// Errors reported by the JIT.
///
/// Each variant has a human-readable rendering through the type's `Display`
/// impl. `PartialEq`/`Eq` are derived so callers and tests can compare errors
/// structurally instead of matching on rendered strings.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JitError {
    /// A graph was rejected as invalid; the payload is a description of the problem.
    InvalidGraph(String),
    /// A type differed from the one required.
    TypeMismatch {
        /// Description of the required type.
        expected: String,
        /// Description of the type actually encountered.
        found: String,
    },
    /// A shape (list of dimensions) differed from the one required.
    ShapeMismatch {
        /// Required dimensions.
        expected: Vec<usize>,
        /// Dimensions actually encountered.
        found: Vec<usize>,
    },
    /// An operation is not supported; the payload identifies the operation.
    UnsupportedOp(String),
    /// Code generation failed; the payload describes the failure.
    CodegenError(String),
    /// A runtime failure. This is also the variant that `From<String>`
    /// produces for bare string messages.
    RuntimeError(String),
    /// A named input could not be found; the payload is the name.
    InputNotFound(String),
    /// A named output could not be found; the payload is the name.
    OutputNotFound(String),
    /// Compilation failed; the payload describes the failure.
    CompilationFailed(String),
}
42
43impl fmt::Display for JitError {
44 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45 match self {
46 Self::InvalidGraph(msg) => write!(f, "Invalid graph: {msg}"),
47 Self::TypeMismatch { expected, found } => {
48 write!(f, "Type mismatch: expected {expected}, found {found}")
49 }
50 Self::ShapeMismatch { expected, found } => {
51 write!(f, "Shape mismatch: expected {expected:?}, found {found:?}")
52 }
53 Self::UnsupportedOp(op) => write!(f, "Unsupported operation: {op}"),
54 Self::CodegenError(msg) => write!(f, "Code generation error: {msg}"),
55 Self::RuntimeError(msg) => write!(f, "Runtime error: {msg}"),
56 Self::InputNotFound(name) => write!(f, "Input not found: {name}"),
57 Self::OutputNotFound(name) => write!(f, "Output not found: {name}"),
58 Self::CompilationFailed(msg) => write!(f, "Compilation failed: {msg}"),
59 }
60 }
61}
62
63impl std::error::Error for JitError {}
64
65impl From<String> for JitError {
66 fn from(msg: String) -> Self {
67 Self::RuntimeError(msg)
68 }
69}
70
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the full rendering, not just the prefix, so a regression in how
    /// `expected` / `found` are formatted is caught.
    #[test]
    fn test_error_display() {
        let err = JitError::TypeMismatch {
            expected: "f32".to_string(),
            found: "i32".to_string(),
        };
        assert_eq!(err.to_string(), "Type mismatch: expected f32, found i32");
    }

    /// Shapes are rendered with `Debug`, i.e. as bracketed lists.
    #[test]
    fn test_shape_mismatch() {
        let err = JitError::ShapeMismatch {
            expected: vec![2, 3],
            found: vec![3, 2],
        };
        assert_eq!(
            err.to_string(),
            "Shape mismatch: expected [2, 3], found [3, 2]"
        );
    }

    /// Every single-payload variant uses its own prefix followed by the payload.
    #[test]
    fn test_string_payload_display() {
        let cases = [
            (JitError::InvalidGraph("cycle".to_string()), "Invalid graph: cycle"),
            (JitError::UnsupportedOp("conv3d".to_string()), "Unsupported operation: conv3d"),
            (JitError::CodegenError("bad reg".to_string()), "Code generation error: bad reg"),
            (JitError::RuntimeError("oom".to_string()), "Runtime error: oom"),
            (JitError::InputNotFound("x".to_string()), "Input not found: x"),
            (JitError::OutputNotFound("y".to_string()), "Output not found: y"),
            (JitError::CompilationFailed("link".to_string()), "Compilation failed: link"),
        ];
        for (err, expected) in &cases {
            assert_eq!(err.to_string(), *expected);
        }
    }

    /// `From<String>` maps a bare message to `RuntimeError` unchanged.
    #[test]
    fn test_from_string_is_runtime_error() {
        let err: JitError = String::from("boom").into();
        assert!(matches!(err, JitError::RuntimeError(ref m) if m == "boom"));
    }

    /// The type works as a boxed `dyn Error` and reports no underlying source.
    #[test]
    fn test_is_std_error() {
        use std::error::Error as _;
        let err: Box<dyn std::error::Error> =
            Box::new(JitError::CodegenError("x".to_string()));
        assert_eq!(err.to_string(), "Code generation error: x");
        assert!(err.source().is_none());
    }
}