mathhook_core/core/expression/evaluation/substitution.rs

//! Variable substitution for expressions
//!
//! Provides [`Expression::substitute`], which replaces symbols with expressions, and
//! [`Expression::substitute_and_simplify`], which additionally simplifies the result.
4
5use super::super::Expression;
6use crate::simplify::Simplify;
7use std::collections::HashMap;
8
9impl Expression {
10    /// Substitute variables with expressions
11    ///
12    /// Recursively replaces all occurrences of symbols with provided expressions.
13    ///
14    /// # Arguments
15    ///
16    /// * `substitutions` - Map from symbol name to replacement expression
17    ///
18    /// # Returns
19    ///
20    /// New expression with substitutions applied
21    ///
22    /// # Examples
23    ///
24    /// ```rust,ignore
25    /// use mathhook_core::{expr, symbol};
26    /// use std::collections::HashMap;
27    ///
28    /// let x = symbol!(x);
29    /// let y = symbol!(y);
30    /// let e = expr!(x + y);
31    ///
32    /// let mut subs = HashMap::new();
33    /// subs.insert("x".to_string(), expr!(3));
34    /// subs.insert("y".to_string(), expr!(4));
35    ///
36    /// let result = e.substitute(&subs);
37    /// assert_eq!(result, expr!(3 + 4));
38    /// ```
39    pub fn substitute(&self, substitutions: &HashMap<String, Expression>) -> Expression {
40        match self {
41            Expression::Number(_) | Expression::Constant(_) => self.clone(),
42
43            Expression::Symbol(sym) => substitutions
44                .get(sym.name())
45                .cloned()
46                .unwrap_or_else(|| self.clone()),
47
48            Expression::Add(terms) => {
49                let new_terms: Vec<Expression> =
50                    terms.iter().map(|t| t.substitute(substitutions)).collect();
51                Expression::add(new_terms)
52            }
53
54            Expression::Mul(factors) => {
55                let new_factors: Vec<Expression> = factors
56                    .iter()
57                    .map(|f| f.substitute(substitutions))
58                    .collect();
59                Expression::mul(new_factors)
60            }
61
62            Expression::Pow(base, exp) => {
63                let new_base = base.substitute(substitutions);
64                let new_exp = exp.substitute(substitutions);
65                Expression::pow(new_base, new_exp)
66            }
67
68            Expression::Function { name, args } => {
69                let new_args: Vec<Expression> = args
70                    .iter()
71                    .map(|arg| arg.substitute(substitutions))
72                    .collect();
73                Expression::function(name.clone(), new_args)
74            }
75
76            Expression::Set(elements) => {
77                let new_elements: Vec<Expression> = elements
78                    .iter()
79                    .map(|e| e.substitute(substitutions))
80                    .collect();
81                Expression::set(new_elements)
82            }
83
84            Expression::Complex(data) => {
85                let new_real = data.real.substitute(substitutions);
86                let new_imag = data.imag.substitute(substitutions);
87                Expression::complex(new_real, new_imag)
88            }
89
90            Expression::Relation(data) => {
91                let new_left = data.left.substitute(substitutions);
92                let new_right = data.right.substitute(substitutions);
93                Expression::relation(new_left, new_right, data.relation_type)
94            }
95
96            Expression::Piecewise(data) => {
97                let new_pieces: Vec<(Expression, Expression)> = data
98                    .pieces
99                    .iter()
100                    .map(|(expr, cond)| {
101                        (
102                            expr.substitute(substitutions),
103                            cond.substitute(substitutions),
104                        )
105                    })
106                    .collect();
107                let new_default = data.default.as_ref().map(|d| d.substitute(substitutions));
108                Expression::piecewise(new_pieces, new_default)
109            }
110
111            Expression::Interval(data) => {
112                let new_start = data.start.substitute(substitutions);
113                let new_end = data.end.substitute(substitutions);
114                Expression::interval(new_start, new_end, data.start_inclusive, data.end_inclusive)
115            }
116
117            Expression::Calculus(data) => {
118                use crate::core::expression::CalculusData;
119                let new_data = match data.as_ref() {
120                    CalculusData::Derivative {
121                        expression,
122                        variable,
123                        order,
124                    } => CalculusData::Derivative {
125                        expression: expression.substitute(substitutions),
126                        variable: variable.clone(),
127                        order: *order,
128                    },
129                    CalculusData::Integral {
130                        integrand,
131                        variable,
132                        bounds,
133                    } => CalculusData::Integral {
134                        integrand: integrand.substitute(substitutions),
135                        variable: variable.clone(),
136                        bounds: bounds.as_ref().map(|(lower, upper)| {
137                            (
138                                lower.substitute(substitutions),
139                                upper.substitute(substitutions),
140                            )
141                        }),
142                    },
143                    CalculusData::Limit {
144                        expression,
145                        variable,
146                        point,
147                        direction,
148                    } => CalculusData::Limit {
149                        expression: expression.substitute(substitutions),
150                        variable: variable.clone(),
151                        point: point.substitute(substitutions),
152                        direction: *direction,
153                    },
154                    CalculusData::Sum {
155                        expression,
156                        variable,
157                        start,
158                        end,
159                    } => CalculusData::Sum {
160                        expression: expression.substitute(substitutions),
161                        variable: variable.clone(),
162                        start: start.substitute(substitutions),
163                        end: end.substitute(substitutions),
164                    },
165                    CalculusData::Product {
166                        expression,
167                        variable,
168                        start,
169                        end,
170                    } => CalculusData::Product {
171                        expression: expression.substitute(substitutions),
172                        variable: variable.clone(),
173                        start: start.substitute(substitutions),
174                        end: end.substitute(substitutions),
175                    },
176                };
177                Expression::Calculus(Box::new(new_data))
178            }
179
180            Expression::MethodCall(data) => {
181                let new_object = data.object.substitute(substitutions);
182                let new_args: Vec<Expression> = data
183                    .args
184                    .iter()
185                    .map(|arg| arg.substitute(substitutions))
186                    .collect();
187                Expression::method_call(new_object, data.method_name.clone(), new_args)
188            }
189
190            Expression::Matrix(_) => self.clone(),
191        }
192    }
193
194    /// Substitute and simplify in one step
195    ///
196    /// Convenience method that applies substitutions and then simplifies the result.
197    ///
198    /// # Arguments
199    ///
200    /// * `substitutions` - Map from symbol name to replacement expression
201    ///
202    /// # Returns
203    ///
204    /// New simplified expression with substitutions applied
205    pub fn substitute_and_simplify(
206        &self,
207        substitutions: &HashMap<String, Expression>,
208    ) -> Expression {
209        self.substitute(substitutions).simplify()
210    }
211}