1use crate::core::Expression;
8use std::fmt;
9
/// Error returned by fallible mathematical operations.
///
/// [`MathResult`] is the matching `Result` alias. Each variant has a
/// human-readable rendering via its `Display` implementation.
#[derive(Debug, Clone, PartialEq)]
pub enum MathError {
    /// An operation received a value outside its domain
    /// (e.g. `sqrt` of a negative number in the real domain).
    DomainError {
        /// Name of the operation that rejected the input.
        operation: String,
        /// The offending input value.
        value: Expression,
        /// Human-readable explanation of why the value is rejected.
        reason: String,
    },

    /// A division by zero was attempted.
    DivisionByZero,

    /// An expression has no defined value.
    Undefined {
        /// The expression that could not be given a value.
        expression: Expression,
        /// Human-readable explanation of why it is undefined.
        reason: String,
    },

    /// A numeric computation overflowed; `operation` names the computation.
    NumericOverflow { operation: String },

    /// The requested feature is not implemented yet; `feature` describes it.
    NotImplemented { feature: String },

    /// `function` evaluated at `at` hits a pole singularity, where it is
    /// undefined (e.g. `log` at `0`).
    Pole { function: String, at: Expression },

    /// Evaluating `function` at `value` touches a branch cut, which needs a
    /// domain or branch choice (such as the complex domain) to resolve.
    BranchCut { function: String, value: Expression },

    /// An interval's bounds are invalid.
    ///
    /// The `Display` message reports this as `lower >= upper`.
    /// NOTE(review): the comparison itself is done by whoever constructs this
    /// variant — confirm callers only use it for that case.
    InvalidInterval { lower: f64, upper: f64 },

    /// An iterative procedure stopped after `max_iterations` steps without
    /// converging.
    MaxIterationsReached { max_iterations: usize },

    /// An iterative procedure failed to converge; `reason` explains why
    /// (e.g. oscillating behavior).
    ConvergenceFailed { reason: String },

    /// `expression` could not be converted to a number.
    ///
    /// For example, it may still contain a symbol such as `x`.
    NonNumericalResult { expression: Expression },
}
91
92impl fmt::Display for MathError {
93 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
94 match self {
95 MathError::DomainError {
96 operation,
97 value,
98 reason,
99 } => {
100 write!(f, "Domain error in {}: {} ({})", operation, value, reason)
101 }
102 MathError::DivisionByZero => {
103 write!(f, "Division by zero")
104 }
105 MathError::Undefined { expression, reason } => {
106 write!(f, "Undefined: {} ({})", expression, reason)
107 }
108 MathError::NumericOverflow { operation } => {
109 write!(f, "Numeric overflow in {}", operation)
110 }
111 MathError::NotImplemented { feature } => {
112 write!(f, "Not yet implemented: {}", feature)
113 }
114 MathError::Pole { function, at } => {
115 write!(f, "Pole singularity: {}({}) is undefined", function, at)
116 }
117 MathError::BranchCut { function, value } => {
118 write!(f, "Branch cut: {}({}) requires domain specification (use complex domain or specify branch)", function, value)
119 }
120 MathError::InvalidInterval { lower, upper } => {
121 write!(
122 f,
123 "Invalid interval: lower bound {} >= upper bound {}",
124 lower, upper
125 )
126 }
127 MathError::MaxIterationsReached { max_iterations } => {
128 write!(
129 f,
130 "Maximum iterations ({}) reached without convergence",
131 max_iterations
132 )
133 }
134 MathError::ConvergenceFailed { reason } => {
135 write!(f, "Convergence failed: {}", reason)
136 }
137 MathError::NonNumericalResult { expression } => {
138 write!(
139 f,
140 "Cannot convert non-numerical expression to number: {}",
141 expression
142 )
143 }
144 }
145 }
146}
147
148impl std::error::Error for MathError {}
149
150pub type MathResult<T> = Result<T, MathError>;
152
/// Unit tests for [`MathError`]: `Display` output for every variant,
/// equality semantics, the `std::error::Error` impl, and the [`MathResult`]
/// alias.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expr;
    use crate::Expression;

    #[test]
    fn test_error_display() {
        let err = MathError::DivisionByZero;
        assert_eq!(err.to_string(), "Division by zero");

        let err = MathError::DomainError {
            operation: "sqrt".to_string(),
            value: Expression::integer(-1),
            reason: "sqrt requires non-negative input in real domain".to_string(),
        };
        let msg = err.to_string();
        assert!(msg.contains("Domain error in sqrt"));
        // The reason is appended in parentheses after the offending value.
        assert!(msg.ends_with("(sqrt requires non-negative input in real domain)"));

        let err = MathError::Pole {
            function: "log".to_string(),
            at: Expression::integer(0),
        };
        let msg = err.to_string();
        assert!(msg.contains("Pole singularity"));
        assert!(msg.starts_with("Pole singularity: log("));
        assert!(msg.ends_with(") is undefined"));
    }

    #[test]
    fn test_undefined_and_branch_cut_display() {
        // Only the parts not rendered by `Expression`'s own Display are pinned,
        // so these assertions do not depend on how expressions are printed.
        let err = MathError::Undefined {
            expression: Expression::integer(0),
            reason: "0/0 is indeterminate".to_string(),
        };
        let msg = err.to_string();
        assert!(msg.starts_with("Undefined: "));
        assert!(msg.ends_with(" (0/0 is indeterminate)"));

        let err = MathError::BranchCut {
            function: "sqrt".to_string(),
            value: Expression::integer(-4),
        };
        let msg = err.to_string();
        assert!(msg.starts_with("Branch cut: sqrt("));
        assert!(msg.ends_with(
            ") requires domain specification (use complex domain or specify branch)"
        ));
    }

    #[test]
    fn test_exact_messages_without_expressions() {
        assert_eq!(
            MathError::NumericOverflow {
                operation: "factorial".to_string(),
            }
            .to_string(),
            "Numeric overflow in factorial"
        );
        assert_eq!(
            MathError::NotImplemented {
                feature: "groebner bases".to_string(),
            }
            .to_string(),
            "Not yet implemented: groebner bases"
        );
    }

    #[test]
    fn test_error_equality() {
        let err1 = MathError::DivisionByZero;
        let err2 = MathError::DivisionByZero;
        assert_eq!(err1, err2);

        let err3 = MathError::NotImplemented {
            feature: "groebner bases".to_string(),
        };
        assert_ne!(err1, err3);

        // Struct variants compare field-by-field.
        assert_eq!(err3.clone(), err3);
        assert_ne!(
            err3,
            MathError::NotImplemented {
                feature: "risch algorithm".to_string(),
            }
        );
    }

    #[test]
    fn test_numerical_errors() {
        let err = MathError::InvalidInterval {
            lower: 1.0,
            upper: 0.0,
        };
        assert_eq!(
            err.to_string(),
            "Invalid interval: lower bound 1 >= upper bound 0"
        );

        let err = MathError::MaxIterationsReached {
            max_iterations: 100,
        };
        assert_eq!(
            err.to_string(),
            "Maximum iterations (100) reached without convergence"
        );

        let err = MathError::ConvergenceFailed {
            reason: "oscillating behavior".to_string(),
        };
        assert_eq!(err.to_string(), "Convergence failed: oscillating behavior");
    }

    #[test]
    fn test_non_numerical_result_error() {
        let err = MathError::NonNumericalResult {
            expression: expr!(x),
        };
        assert!(err
            .to_string()
            .contains("Cannot convert non-numerical expression to number"));
    }

    #[test]
    fn test_std_error_and_result_alias() {
        // `MathError` wraps no underlying cause.
        assert!(std::error::Error::source(&MathError::DivisionByZero).is_none());

        let boxed: Box<dyn std::error::Error> = Box::new(MathError::DivisionByZero);
        assert_eq!(boxed.to_string(), "Division by zero");

        fn checked_div(a: i64, b: i64) -> MathResult<i64> {
            if b == 0 {
                return Err(MathError::DivisionByZero);
            }
            Ok(a / b)
        }
        fn plus_one(a: i64, b: i64) -> MathResult<i64> {
            Ok(checked_div(a, b)? + 1)
        }
        assert_eq!(plus_one(6, 3), Ok(3));
        assert_eq!(plus_one(6, 0), Err(MathError::DivisionByZero));
    }
}