mathhook_core/core/expression/evaluation/
substitution.rs

//! Variable substitution for expressions
//!
//! Contains the `substitute()` and `substitute_and_simplify()` methods for replacing
//! symbols with expressions.
4
5use super::super::Expression;
6use crate::simplify::Simplify;
7use std::collections::HashMap;
8use std::sync::Arc;
9
10impl Expression {
11    /// Substitute variables with expressions
12    ///
13    /// Recursively replaces all occurrences of symbols with provided expressions.
14    ///
15    /// # Arguments
16    ///
17    /// * `substitutions` - Map from symbol name to replacement expression
18    ///
19    /// # Returns
20    ///
21    /// New expression with substitutions applied
22    ///
23    /// # Examples
24    ///
25    /// ```rust,ignore
26    /// use mathhook_core::{expr, symbol};
27    /// use std::collections::HashMap;
28    ///
29    /// let x = symbol!(x);
30    /// let y = symbol!(y);
31    /// let e = expr!(x + y);
32    ///
33    /// let mut subs = HashMap::new();
34    /// subs.insert("x".to_string(), expr!(3));
35    /// subs.insert("y".to_string(), expr!(4));
36    ///
37    /// let result = e.substitute(&subs);
38    /// assert_eq!(result, expr!(3 + 4));
39    /// ```
40    pub fn substitute(&self, substitutions: &HashMap<String, Expression>) -> Expression {
41        match self {
42            Expression::Number(_) | Expression::Constant(_) => self.clone(),
43
44            Expression::Symbol(sym) => substitutions
45                .get(sym.name())
46                .cloned()
47                .unwrap_or_else(|| self.clone()),
48
49            Expression::Add(terms) => {
50                let new_terms: Vec<Expression> =
51                    terms.iter().map(|t| t.substitute(substitutions)).collect();
52                Expression::add(new_terms)
53            }
54
55            Expression::Mul(factors) => {
56                let new_factors: Vec<Expression> = factors
57                    .iter()
58                    .map(|f| f.substitute(substitutions))
59                    .collect();
60                Expression::mul(new_factors)
61            }
62
63            Expression::Pow(base, exp) => {
64                let new_base = base.substitute(substitutions);
65                let new_exp = exp.substitute(substitutions);
66                Expression::pow(new_base, new_exp)
67            }
68
69            Expression::Function { name, args } => {
70                let new_args: Vec<Expression> = args
71                    .iter()
72                    .map(|arg| arg.substitute(substitutions))
73                    .collect();
74                Expression::function(name.clone(), new_args)
75            }
76
77            Expression::Set(elements) => {
78                let new_elements: Vec<Expression> = elements
79                    .iter()
80                    .map(|e| e.substitute(substitutions))
81                    .collect();
82                Expression::set(new_elements)
83            }
84
85            Expression::Complex(data) => {
86                let new_real = data.real.substitute(substitutions);
87                let new_imag = data.imag.substitute(substitutions);
88                Expression::complex(new_real, new_imag)
89            }
90
91            Expression::Relation(data) => {
92                let new_left = data.left.substitute(substitutions);
93                let new_right = data.right.substitute(substitutions);
94                Expression::relation(new_left, new_right, data.relation_type)
95            }
96
97            Expression::Piecewise(data) => {
98                let new_pieces: Vec<(Expression, Expression)> = data
99                    .pieces
100                    .iter()
101                    .map(|(expr, cond)| {
102                        (
103                            expr.substitute(substitutions),
104                            cond.substitute(substitutions),
105                        )
106                    })
107                    .collect();
108                let new_default = data.default.as_ref().map(|d| d.substitute(substitutions));
109                Expression::piecewise(new_pieces, new_default)
110            }
111
112            Expression::Interval(data) => {
113                let new_start = data.start.substitute(substitutions);
114                let new_end = data.end.substitute(substitutions);
115                Expression::interval(new_start, new_end, data.start_inclusive, data.end_inclusive)
116            }
117
118            Expression::Calculus(data) => {
119                use crate::core::expression::CalculusData;
120                let new_data = match data.as_ref() {
121                    CalculusData::Derivative {
122                        expression,
123                        variable,
124                        order,
125                    } => CalculusData::Derivative {
126                        expression: expression.substitute(substitutions),
127                        variable: variable.clone(),
128                        order: *order,
129                    },
130                    CalculusData::Integral {
131                        integrand,
132                        variable,
133                        bounds,
134                    } => CalculusData::Integral {
135                        integrand: integrand.substitute(substitutions),
136                        variable: variable.clone(),
137                        bounds: bounds.as_ref().map(|(lower, upper)| {
138                            (
139                                lower.substitute(substitutions),
140                                upper.substitute(substitutions),
141                            )
142                        }),
143                    },
144                    CalculusData::Limit {
145                        expression,
146                        variable,
147                        point,
148                        direction,
149                    } => CalculusData::Limit {
150                        expression: expression.substitute(substitutions),
151                        variable: variable.clone(),
152                        point: point.substitute(substitutions),
153                        direction: *direction,
154                    },
155                    CalculusData::Sum {
156                        expression,
157                        variable,
158                        start,
159                        end,
160                    } => CalculusData::Sum {
161                        expression: expression.substitute(substitutions),
162                        variable: variable.clone(),
163                        start: start.substitute(substitutions),
164                        end: end.substitute(substitutions),
165                    },
166                    CalculusData::Product {
167                        expression,
168                        variable,
169                        start,
170                        end,
171                    } => CalculusData::Product {
172                        expression: expression.substitute(substitutions),
173                        variable: variable.clone(),
174                        start: start.substitute(substitutions),
175                        end: end.substitute(substitutions),
176                    },
177                };
178                Expression::Calculus(Arc::new(new_data))
179            }
180
181            Expression::MethodCall(data) => {
182                let new_object = data.object.substitute(substitutions);
183                let new_args: Vec<Expression> = data
184                    .args
185                    .iter()
186                    .map(|arg| arg.substitute(substitutions))
187                    .collect();
188                Expression::method_call(new_object, data.method_name.clone(), new_args)
189            }
190
191            Expression::Matrix(_) => self.clone(),
192        }
193    }
194
195    /// Substitute and simplify in one step
196    ///
197    /// Convenience method that applies substitutions and then simplifies the result.
198    ///
199    /// # Arguments
200    ///
201    /// * `substitutions` - Map from symbol name to replacement expression
202    ///
203    /// # Returns
204    ///
205    /// New simplified expression with substitutions applied
206    pub fn substitute_and_simplify(
207        &self,
208        substitutions: &HashMap<String, Expression>,
209    ) -> Expression {
210        self.substitute(substitutions).simplify()
211    }
212}