petite_ad/
error.rs

1//! Error types for automatic differentiation operations.
2
3use std::fmt;
4
/// Errors that can occur during automatic differentiation computations.
///
/// Every field is a `&'static str` or a `usize`, so in addition to `Debug`,
/// `Clone` and `PartialEq` the type also derives `Eq` and `Hash`. This lets
/// errors be used with `Eq`-bounded APIs and stored in hashed collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AutodiffError {
    /// An operation received the wrong number of arguments.
    ArityError {
        /// Name of the operation that was called.
        operation: &'static str,
        /// Number of arguments the operation expects.
        expected: usize,
        /// Number of arguments the operation actually received.
        actual: usize,
    },
    /// The computation graph is empty or invalid.
    EmptyGraph,
    /// An index references a non-existent value in the computation.
    IndexOutOfBounds {
        /// The invalid index.
        index: usize,
        /// The maximum valid index.
        max_index: usize,
    },
}
27
28impl fmt::Display for AutodiffError {
29    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30        match self {
31            AutodiffError::ArityError {
32                operation,
33                expected,
34                actual,
35            } => write!(
36                f,
37                "Arity error in {}: expected {}, got {}",
38                operation, expected, actual
39            ),
40            AutodiffError::EmptyGraph => write!(f, "Computation graph is empty"),
41            AutodiffError::IndexOutOfBounds { index, max_index } => {
42                write!(f, "Index {} is out of bounds (max: {})", index, max_index)
43            }
44        }
45    }
46}
47
48impl std::error::Error for AutodiffError {}
49
50impl AutodiffError {
51    /// Create an ArityError for an operation with incorrect argument count.
52    pub fn arity(operation: &'static str, expected: usize, actual: usize) -> Self {
53        AutodiffError::ArityError {
54            operation,
55            expected,
56            actual,
57        }
58    }
59
60    /// Validate that an operation received the correct number of arguments.
61    pub fn check_arity(
62        operation: &'static str,
63        expected: usize,
64        actual: usize,
65    ) -> std::result::Result<(), AutodiffError> {
66        if actual == expected {
67            Ok(())
68        } else {
69            Err(AutodiffError::arity(operation, expected, actual))
70        }
71    }
72}
73
74/// Result type for automatic differentiation operations.
75pub type Result<T> = std::result::Result<T, AutodiffError>;