mathhook_core/core/expression/
eval_numeric.rs

1//! Low-level numerical evaluation trait system
2//!
3//! This module implements the two-level evaluation architecture following SymPy's proven design:
4//!
5//! 1. **Low-level `EvalNumeric` trait:** Type-specific numerical evaluation without substitution
6//! 2. **High-level `evaluate()` method:** User-facing API with substitution, simplification, and control
7//!
8//! # Architecture
9//!
10//! The two-level design separates concerns:
11//!
12//! - `EvalNumeric::eval_numeric()`: Internal trait for numerical conversion (like SymPy's `_eval_evalf()`)
13//! - `Expression::evaluate_with_context()`: Public API with substitution and context control (like SymPy's `evalf()`)
14//!
15//! This separation enables:
16//! - Clear semantics for each expression type
17//! - Explicit control over evaluation behavior
18//! - Extensibility for custom types
19//!
20//! # Mathematical Background
21//!
22//! Numerical evaluation converts symbolic expressions to numerical form while preserving
23//! mathematical correctness. For example:
24//!
//! - `sin(π/2)` → `1.0` (π is replaced by its `f64` value before `sin` is evaluated numerically)
26//! - `sqrt(2)` → `1.4142135623730951` (numerical approximation with precision control)
27//! - `x^2` (with x=3) → `9` (after substitution and evaluation)
28use crate::core::number::Number;
29use crate::core::Expression;
30use crate::error::MathError;
31use num_bigint::BigInt;
32use num_rational::BigRational;
33use std::collections::HashMap;
34
35pub trait EvalNumeric {
36    /// Evaluate expression to numerical form
37    ///
38    /// # Arguments
39    ///
40    /// * `precision` - Number of bits of precision for numerical operations (default: 53 for f64)
41    ///
42    /// # Returns
43    ///
44    /// Expression in numerical form (may contain Number, Complex, Matrix of numbers, etc.)
45    ///
46    /// # Errors
47    ///
48    /// Returns `MathError` for:
49    /// - Domain violations (sqrt of negative, log of zero, etc.)
50    /// - Undefined operations (0/0, inf-inf, etc.)
51    /// - Numerical overflow/underflow
52    ///
53    /// # Implementation Requirements
54    ///
55    /// Implementations MUST:
56    /// 1. Handle domain restrictions correctly (return error for invalid inputs)
57    /// 2. Preserve mathematical correctness (exact evaluation when possible)
58    /// 3. Use specified precision for floating-point operations
59    /// 4. NOT perform variable substitution (that's `evaluate_with_context()`'s job)
60    fn eval_numeric(&self, precision: u32) -> Result<Expression, MathError>;
61}
62
63/// Evaluation context
64///
65/// Controls how `Expression::evaluate_with_context()` behaves. Provides variable substitutions,
66/// numerical evaluation control, and simplification options.
67///
68/// This mirrors SymPy's `evalf(subs={...}, ...)` high-level API.
69///
70/// # Two-Level Architecture
71///
72/// The context enables separation of concerns:
73///
74/// 1. **Variable substitution:** Replace symbols with values before evaluation
75/// 2. **Simplification control:** Optionally simplify symbolically first
76/// 3. **Numerical evaluation:** Convert to numerical form if requested
77///
78/// # Examples
79///
80/// ```rust
81/// use mathhook_core::{expr, symbol};
82/// use mathhook_core::core::expression::eval_numeric::EvalContext;
83/// use std::collections::HashMap;
84///
85/// // Symbolic evaluation (no numerical conversion)
86/// let ctx = EvalContext::symbolic();
87/// assert!(!ctx.numeric);
88/// assert!(ctx.variables.is_empty());
89///
90/// // Numerical evaluation with substitutions
91/// let mut vars = HashMap::new();
92/// vars.insert("x".to_string(), expr!(5));
93/// let ctx = EvalContext::numeric(vars);
94/// assert!(ctx.numeric);
95/// assert_eq!(ctx.variables.len(), 1);
96///
97/// // Custom precision
98/// let ctx = EvalContext::symbolic().with_precision(128);
99/// assert_eq!(ctx.precision, 128);
100/// ```
101#[derive(Debug, Clone)]
102pub struct EvalContext {
103    /// Variable substitutions (symbol name → value)
104    ///
105    /// Before evaluation, all symbols matching these names will be replaced
106    /// with the provided expressions. This enables parameterized evaluation.
107    ///
108    /// # Examples
109    ///
110    /// ```rust
111    /// use std::collections::HashMap;
112    /// use mathhook_core::{expr, Expression};
113    ///
114    /// let mut vars = HashMap::new();
115    /// vars.insert("x".to_string(), expr!(3));
116    /// vars.insert("y".to_string(), expr!(4));
117    /// // Now evaluating "x + y" will substitute → "3 + 4" → "7"
118    /// ```
119    pub variables: HashMap<String, Expression>,
120
121    /// Whether to perform numerical evaluation (evalf-style)
122    ///
123    /// - `true`: Convert to numerical form using `eval_numeric()`
124    /// - `false`: Keep symbolic form (only substitute variables)
125    pub numeric: bool,
126
127    /// Precision for numerical operations (bits)
128    ///
129    /// Controls accuracy of floating-point operations:
130    /// - 53 bits: f64 precision (default)
131    /// - 64 bits: Extended precision
132    /// - 128+ bits: Arbitrary precision (future)
133    ///
134    /// Note: Current implementation uses f64, so precision >53 has no effect yet.
135    /// Future versions will support arbitrary precision via `rug` or `mpc`.
136    pub precision: u32,
137
138    /// Whether to simplify symbolically before numerical evaluation
139    ///
140    /// - `true`: Call `simplify()` before `eval_numeric()` (recommended)
141    /// - `false`: Evaluate directly without simplification
142    ///
143    /// Simplification often improves numerical stability by reducing expression complexity.
144    pub simplify_first: bool,
145}
146
147impl EvalContext {
148    /// Create context for symbolic evaluation (no numerical conversion)
149    ///
150    /// Returns a context that performs variable substitution but keeps expressions
151    /// in symbolic form. No numerical evaluation is performed.
152    ///
153    /// # Returns
154    ///
155    /// Context with:
156    /// - No variable substitutions
157    /// - Symbolic mode (numeric = false)
158    /// - Default precision (53 bits)
159    /// - No pre-simplification
160    ///
161    /// # Examples
162    ///
163    /// ```rust
164    /// use mathhook_core::{expr, symbol};
165    /// use mathhook_core::core::expression::eval_numeric::EvalContext;
166    ///
167    /// let x = symbol!(x);
168    /// let e = expr!((x ^ 2) + (2*x) + 1);
169    ///
170    /// let ctx = EvalContext::symbolic();
171    /// let result = e.evaluate_with_context(&ctx).unwrap();
172    /// // Result is still symbolic: x^2 + 2*x + 1
173    /// ```
174    pub fn symbolic() -> Self {
175        Self {
176            variables: HashMap::new(),
177            numeric: false,
178            precision: 53,
179            simplify_first: false,
180        }
181    }
182
183    /// Create context for numerical evaluation with substitutions
184    ///
185    /// Returns a context that substitutes variables and converts to numerical form.
186    /// Simplification is enabled by default for numerical stability.
187    ///
188    /// # Arguments
189    ///
190    /// * `variables` - Map from symbol name to replacement expression
191    ///
192    /// # Returns
193    ///
194    /// Context with:
195    /// - Provided variable substitutions
196    /// - Numerical mode (numeric = true)
197    /// - Default precision (53 bits for f64)
198    /// - Pre-simplification enabled (simplify_first = true)
199    ///
200    /// # Examples
201    ///
202    /// ```rust
203    /// use mathhook_core::{expr, symbol};
204    /// use mathhook_core::core::expression::eval_numeric::EvalContext;
205    /// use std::collections::HashMap;
206    ///
207    /// let x = symbol!(x);
208    /// let e = expr!((x ^ 2) + (2*x) + 1);
209    ///
210    /// let mut vars = HashMap::new();
211    /// vars.insert("x".to_string(), expr!(3));
212    ///
213    /// let ctx = EvalContext::numeric(vars);
214    /// let result = e.evaluate_with_context(&ctx).unwrap();
215    /// // Result is numerical: 16 (= 3^2 + 2*3 + 1)
216    /// ```
217    pub fn numeric(variables: HashMap<String, Expression>) -> Self {
218        Self {
219            variables,
220            numeric: true,
221            precision: 53,
222            simplify_first: true,
223        }
224    }
225
226    /// Set precision for numerical operations (bits)
227    ///
228    /// Consumes self and returns a new context with the specified precision.
229    ///
230    /// # Arguments
231    ///
232    /// * `precision` - Number of bits of precision (53 for f64, 128+ for arbitrary precision)
233    ///
234    /// # Returns
235    ///
236    /// New context with updated precision
237    ///
238    /// # Examples
239    ///
240    /// ```rust
241    /// use mathhook_core::core::expression::eval_numeric::EvalContext;
242    ///
243    /// let ctx = EvalContext::symbolic().with_precision(128);
244    /// assert_eq!(ctx.precision, 128);
245    /// ```
246    pub fn with_precision(mut self, precision: u32) -> Self {
247        self.precision = precision;
248        self
249    }
250
251    /// Control whether to simplify symbolically before numerical evaluation
252    ///
253    /// Consumes self and returns a new context with the specified simplification flag.
254    ///
255    /// # Arguments
256    ///
257    /// * `simplify` - Whether to call `simplify()` before `eval_numeric()`
258    ///
259    /// # Returns
260    ///
261    /// New context with updated simplification setting
262    ///
263    /// # Examples
264    ///
265    /// ```rust
266    /// use mathhook_core::core::expression::eval_numeric::EvalContext;
267    ///
268    /// // Disable simplification for performance
269    /// let ctx = EvalContext::symbolic().with_simplify(false);
270    /// assert!(!ctx.simplify_first);
271    ///
272    /// // Enable simplification for numerical stability
273    /// let ctx = EvalContext::symbolic().with_simplify(true);
274    /// assert!(ctx.simplify_first);
275    /// ```
276    pub fn with_simplify(mut self, simplify: bool) -> Self {
277        self.simplify_first = simplify;
278        self
279    }
280}
281
282impl Default for EvalContext {
283    /// Default context is symbolic (no numerical evaluation)
284    ///
285    /// # Examples
286    ///
287    /// ```rust
288    /// use mathhook_core::core::expression::eval_numeric::EvalContext;
289    ///
290    /// let ctx = EvalContext::default();
291    /// assert!(!ctx.numeric);
292    /// assert!(ctx.variables.is_empty());
293    /// ```
294    fn default() -> Self {
295        Self::symbolic()
296    }
297}
298
299fn is_number_negative(n: &Number) -> bool {
300    match n {
301        Number::Integer(i) => *i < 0,
302        Number::Float(f) => *f < 0.0,
303        Number::BigInteger(bi) => **bi < BigInt::from(0),
304        Number::Rational(r) => **r < BigRational::new(BigInt::from(0), BigInt::from(1)),
305    }
306}
307
308impl EvalNumeric for Expression {
309    fn eval_numeric(&self, _precision: u32) -> Result<Expression, MathError> {
310        match self {
311            Expression::Number(_) => Ok(self.clone()),
312            Expression::Symbol(_) => Ok(self.clone()),
313
314            Expression::Constant(c) => {
315                use crate::core::MathConstant;
316                match c {
317                    MathConstant::Pi => Ok(Expression::float(std::f64::consts::PI)),
318                    MathConstant::E => Ok(Expression::float(std::f64::consts::E)),
319                    MathConstant::I => Ok(self.clone()),
320                    MathConstant::Infinity => Ok(self.clone()),
321                    MathConstant::NegativeInfinity => Ok(self.clone()),
322                    MathConstant::Undefined => Ok(self.clone()),
323                    MathConstant::GoldenRatio => {
324                        Ok(Expression::float(MathConstant::GoldenRatio.to_f64()))
325                    }
326                    MathConstant::EulerGamma => {
327                        Ok(Expression::float(MathConstant::EulerGamma.to_f64()))
328                    }
329                    MathConstant::TribonacciConstant => {
330                        Ok(Expression::float(MathConstant::TribonacciConstant.to_f64()))
331                    }
332                }
333            }
334
335            Expression::Add(terms) => {
336                let evaluated: Result<Vec<_>, _> =
337                    terms.iter().map(|t| t.eval_numeric(_precision)).collect();
338                Ok(Expression::add(evaluated?))
339            }
340
341            Expression::Mul(factors) => {
342                let evaluated: Result<Vec<_>, _> =
343                    factors.iter().map(|f| f.eval_numeric(_precision)).collect();
344                Ok(Expression::mul(evaluated?))
345            }
346
347            Expression::Pow(base, exp) => {
348                let base_eval = base.eval_numeric(_precision)?;
349                let exp_eval = exp.eval_numeric(_precision)?;
350
351                if base_eval.is_zero() {
352                    if let Expression::Number(n) = &exp_eval {
353                        if is_number_negative(n) {
354                            return Err(MathError::DivisionByZero);
355                        }
356                    }
357                }
358
359                Ok(Expression::pow(base_eval, exp_eval))
360            }
361
362            Expression::Function { name, args } => {
363                let eval_args = args
364                    .iter()
365                    .map(|arg| arg.eval_numeric(_precision))
366                    .collect::<Result<Vec<_>, _>>()?;
367
368                if let Some(result) =
369                    super::evaluation::evaluate_function_dispatch(name, &eval_args)
370                {
371                    return Ok(result);
372                }
373
374                Ok(Expression::function(name.clone(), eval_args))
375            }
376
377            Expression::Matrix(matrix) => {
378                let (rows, cols) = matrix.dimensions();
379                let mut new_rows = Vec::with_capacity(rows);
380
381                for i in 0..rows {
382                    let mut row = Vec::with_capacity(cols);
383                    for j in 0..cols {
384                        let element = matrix.get_element(i, j);
385                        row.push(element.eval_numeric(_precision)?);
386                    }
387                    new_rows.push(row);
388                }
389
390                Ok(Expression::matrix(new_rows))
391            }
392
393            Expression::Set(elements) => {
394                let evaluated: Result<Vec<_>, _> = elements
395                    .iter()
396                    .map(|e| e.eval_numeric(_precision))
397                    .collect();
398                Ok(Expression::set(evaluated?))
399            }
400
401            Expression::Complex(data) => {
402                let real_eval = data.real.eval_numeric(_precision)?;
403                let imag_eval = data.imag.eval_numeric(_precision)?;
404                Ok(Expression::complex(real_eval, imag_eval))
405            }
406
407            Expression::Interval(interval) => {
408                let start_eval = interval.start.eval_numeric(_precision)?;
409                let end_eval = interval.end.eval_numeric(_precision)?;
410
411                Ok(Expression::interval(
412                    start_eval,
413                    end_eval,
414                    interval.start_inclusive,
415                    interval.end_inclusive,
416                ))
417            }
418
419            Expression::Piecewise(data) => {
420                let mut new_pieces = Vec::with_capacity(data.pieces.len());
421
422                for (expr, cond) in &data.pieces {
423                    let expr_eval = expr.eval_numeric(_precision)?;
424                    new_pieces.push((expr_eval, cond.clone()));
425                }
426
427                let default_eval = if let Some(ref default) = data.default {
428                    Some(default.eval_numeric(_precision)?)
429                } else {
430                    None
431                };
432
433                Ok(Expression::piecewise(new_pieces, default_eval))
434            }
435
436            Expression::Relation(rel) => {
437                let lhs_eval = rel.left.eval_numeric(_precision)?;
438                let rhs_eval = rel.right.eval_numeric(_precision)?;
439
440                Ok(Expression::relation(lhs_eval, rhs_eval, rel.relation_type))
441            }
442
443            Expression::Calculus(_) => Ok(self.clone()),
444
445            Expression::MethodCall(_) => Ok(self.clone()),
446        }
447    }
448}
449
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_eval_context_symbolic() {
        let ctx = EvalContext::symbolic();
        assert!(!ctx.numeric);
        assert!(ctx.variables.is_empty());
        assert_eq!(ctx.precision, 53);
        assert!(!ctx.simplify_first);
    }

    #[test]
    fn test_eval_context_numeric() {
        let mut vars = HashMap::new();
        vars.insert("x".to_string(), Expression::integer(5));
        let ctx = EvalContext::numeric(vars);

        assert!(ctx.numeric);
        assert_eq!(ctx.variables.len(), 1);
        assert_eq!(ctx.precision, 53);
        assert!(ctx.simplify_first);
    }

    #[test]
    fn test_eval_context_with_precision() {
        let ctx = EvalContext::symbolic().with_precision(128);
        assert_eq!(ctx.precision, 128);
    }

    #[test]
    fn test_eval_context_with_simplify() {
        let ctx = EvalContext::symbolic().with_simplify(true);
        assert!(ctx.simplify_first);

        let ctx = EvalContext::symbolic().with_simplify(false);
        assert!(!ctx.simplify_first);
    }

    #[test]
    fn test_eval_context_default() {
        let ctx = EvalContext::default();
        assert!(!ctx.numeric);
        assert!(ctx.variables.is_empty());
    }

    #[test]
    fn test_eval_context_chaining() {
        let mut vars = HashMap::new();
        vars.insert("x".to_string(), Expression::integer(3));

        let ctx = EvalContext::numeric(vars)
            .with_precision(128)
            .with_simplify(false);

        assert!(ctx.numeric);
        assert_eq!(ctx.variables.len(), 1);
        assert_eq!(ctx.precision, 128);
        assert!(!ctx.simplify_first);
    }

    #[test]
    fn test_is_number_negative_machine_numbers() {
        assert!(is_number_negative(&Number::Integer(-3)));
        assert!(!is_number_negative(&Number::Integer(0)));
        assert!(!is_number_negative(&Number::Integer(4)));

        assert!(is_number_negative(&Number::Float(-0.5)));
        assert!(!is_number_negative(&Number::Float(0.0)));
        assert!(!is_number_negative(&Number::Float(2.5)));
    }

    #[test]
    fn test_eval_numeric_keeps_numbers() {
        let result = Expression::integer(7).eval_numeric(53);
        assert!(matches!(result, Ok(Expression::Number(_))));
    }

    #[test]
    fn test_eval_numeric_pi_becomes_number() {
        use crate::core::MathConstant;

        let result = Expression::Constant(MathConstant::Pi).eval_numeric(53);
        assert!(matches!(result, Ok(Expression::Number(_))));
    }
}