1use std::fmt;
4
/// Errors produced by the autodiff machinery.
///
/// All payloads are plain `Copy` data (`&'static str` / `usize`), so the type
/// is cheap to clone and supports full (`Eq`) equality.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AutodiffError {
    /// An operation was given a different number of inputs than it requires.
    ArityError {
        /// Name of the operation whose arity check failed.
        operation: &'static str,
        /// Number of inputs the operation requires.
        expected: usize,
        /// Number of inputs that were actually supplied.
        actual: usize,
    },
    /// The computation graph is empty.
    EmptyGraph,
    /// An index fell outside the permitted range.
    IndexOutOfBounds {
        /// The offending index.
        index: usize,
        /// The largest index reported as valid.
        // NOTE(review): presumably an inclusive bound (printed as "max") —
        // confirm against the call sites that construct this variant.
        max_index: usize,
    },
}
27
28impl fmt::Display for AutodiffError {
29 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30 match self {
31 AutodiffError::ArityError {
32 operation,
33 expected,
34 actual,
35 } => write!(
36 f,
37 "Arity error in {}: expected {}, got {}",
38 operation, expected, actual
39 ),
40 AutodiffError::EmptyGraph => write!(f, "Computation graph is empty"),
41 AutodiffError::IndexOutOfBounds { index, max_index } => {
42 write!(f, "Index {} is out of bounds (max: {})", index, max_index)
43 }
44 }
45 }
46}
47
48impl std::error::Error for AutodiffError {}
49
50impl AutodiffError {
51 pub fn arity(operation: &'static str, expected: usize, actual: usize) -> Self {
53 AutodiffError::ArityError {
54 operation,
55 expected,
56 actual,
57 }
58 }
59
60 pub fn check_arity(
62 operation: &'static str,
63 expected: usize,
64 actual: usize,
65 ) -> std::result::Result<(), AutodiffError> {
66 if actual == expected {
67 Ok(())
68 } else {
69 Err(AutodiffError::arity(operation, expected, actual))
70 }
71 }
72}
73
74pub type Result<T> = std::result::Result<T, AutodiffError>;