mathhook_core/pattern/substitution/core.rs

1//! Core substitution trait and single-expression substitution implementation
2
3use crate::core::Expression;
4use crate::simplify::Simplify;
5
6/// Trait for types that support substitution operations
7pub trait Substitutable {
8    /// Substitute a single expression with another
9    ///
10    /// Recursively walks the expression tree and replaces all occurrences
11    /// of `old` with `new`. The replacement is structural - it compares
12    /// expressions using PartialEq.
13    ///
14    /// # Arguments
15    ///
16    /// * `old` - The expression to replace
17    /// * `new` - The expression to substitute in
18    ///
19    /// # Examples
20    ///
21    /// ```
22    /// use mathhook_core::prelude::*;
23    /// use mathhook_core::pattern::Substitutable;
24    ///
25    /// let x = symbol!(x);
26    /// let expr = Expression::add(vec![
27    ///     Expression::symbol(x.clone()),
28    ///     Expression::integer(1)
29    /// ]);
30    ///
31    /// let result = expr.subs(&Expression::symbol(x.clone()), &Expression::integer(5));
32    /// let expected = Expression::add(vec![
33    ///     Expression::integer(5),
34    ///     Expression::integer(1)
35    /// ]);
36    ///
37    /// assert_eq!(result, expected);
38    /// ```
39    fn subs(&self, old: &Expression, new: &Expression) -> Expression;
40
41    /// Apply multiple substitutions simultaneously
42    ///
43    /// This is more efficient than chaining multiple `subs()` calls because
44    /// it performs all substitutions in a single tree traversal.
45    ///
46    /// # Arguments
47    ///
48    /// * `substitutions` - Slice of (old, new) expression pairs
49    ///
50    /// # Examples
51    ///
52    /// ```
53    /// use mathhook_core::prelude::*;
54    /// use mathhook_core::pattern::Substitutable;
55    ///
56    /// let x = symbol!(x);
57    /// let y = symbol!(y);
58    /// let expr = Expression::add(vec![
59    ///     Expression::symbol(x.clone()),
60    ///     Expression::symbol(y.clone())
61    /// ]);
62    ///
63    /// let result = expr.subs_multiple(&[
64    ///     (Expression::symbol(x.clone()), Expression::integer(1)),
65    ///     (Expression::symbol(y.clone()), Expression::integer(2)),
66    /// ]);
67    ///
68    /// let expected = Expression::add(vec![
69    ///     Expression::integer(1),
70    ///     Expression::integer(2)
71    /// ]);
72    ///
73    /// assert_eq!(result, expected);
74    /// ```
75    fn subs_multiple(&self, substitutions: &[(Expression, Expression)]) -> Expression;
76}
77
78impl Substitutable for Expression {
79    fn subs(&self, old: &Expression, new: &Expression) -> Expression {
80        if self == old {
81            return new.clone();
82        }
83
84        let result = match self {
85            Expression::Number(_) | Expression::Constant(_) => self.clone(),
86
87            Expression::Symbol(_) => self.clone(),
88
89            Expression::Add(terms) => {
90                let new_terms: Vec<Expression> = terms.iter().map(|t| t.subs(old, new)).collect();
91                Expression::Add(Box::new(new_terms))
92            }
93
94            Expression::Mul(factors) => {
95                let new_factors: Vec<Expression> =
96                    factors.iter().map(|f| f.subs(old, new)).collect();
97                Expression::mul(new_factors)
98            }
99
100            Expression::Pow(base, exp) => {
101                let new_base = base.subs(old, new);
102                let new_exp = exp.subs(old, new);
103                Expression::Pow(Box::new(new_base), Box::new(new_exp))
104            }
105
106            Expression::Function { name, args } => {
107                let new_args: Vec<Expression> = args.iter().map(|a| a.subs(old, new)).collect();
108                Expression::Function {
109                    name: name.clone(),
110                    args: Box::new(new_args),
111                }
112            }
113
114            Expression::Set(elements) => {
115                let new_elements: Vec<Expression> =
116                    elements.iter().map(|e| e.subs(old, new)).collect();
117                Expression::Set(Box::new(new_elements))
118            }
119
120            Expression::Complex(data) => {
121                let new_real = data.real.subs(old, new);
122                let new_imag = data.imag.subs(old, new);
123                Expression::Complex(Box::new(crate::core::expression::ComplexData {
124                    real: new_real,
125                    imag: new_imag,
126                }))
127            }
128
129            Expression::Matrix(matrix) => {
130                let (rows, cols) = matrix.dimensions();
131                let mut new_data: Vec<Vec<Expression>> = Vec::with_capacity(rows);
132
133                for i in 0..rows {
134                    let mut row: Vec<Expression> = Vec::with_capacity(cols);
135                    for j in 0..cols {
136                        let elem = matrix.get_element(i, j);
137                        row.push(elem.subs(old, new));
138                    }
139                    new_data.push(row);
140                }
141
142                Expression::Matrix(Box::new(crate::matrices::unified::Matrix::dense(new_data)))
143            }
144
145            Expression::Relation(data) => {
146                let new_left = data.left.subs(old, new);
147                let new_right = data.right.subs(old, new);
148                Expression::Relation(Box::new(crate::core::expression::RelationData {
149                    left: new_left,
150                    right: new_right,
151                    relation_type: data.relation_type,
152                }))
153            }
154
155            Expression::Piecewise(data) => {
156                let new_pieces: Vec<(Expression, Expression)> = data
157                    .pieces
158                    .iter()
159                    .map(|(expr, cond)| (expr.subs(old, new), cond.subs(old, new)))
160                    .collect();
161
162                let new_default = data.default.as_ref().map(|d| d.subs(old, new));
163
164                Expression::Piecewise(Box::new(crate::core::expression::PiecewiseData {
165                    pieces: new_pieces,
166                    default: new_default,
167                }))
168            }
169
170            Expression::Interval(data) => {
171                let new_start = data.start.subs(old, new);
172                let new_end = data.end.subs(old, new);
173                Expression::Interval(Box::new(crate::core::expression::IntervalData {
174                    start: new_start,
175                    end: new_end,
176                    start_inclusive: data.start_inclusive,
177                    end_inclusive: data.end_inclusive,
178                }))
179            }
180
181            Expression::Calculus(data) => {
182                use crate::core::expression::CalculusData;
183
184                let new_data = match data.as_ref() {
185                    CalculusData::Derivative {
186                        expression,
187                        variable,
188                        order,
189                    } => CalculusData::Derivative {
190                        expression: expression.subs(old, new),
191                        variable: variable.clone(),
192                        order: *order,
193                    },
194
195                    CalculusData::Integral {
196                        integrand,
197                        variable,
198                        bounds,
199                    } => CalculusData::Integral {
200                        integrand: integrand.subs(old, new),
201                        variable: variable.clone(),
202                        bounds: bounds
203                            .as_ref()
204                            .map(|(a, b)| (a.subs(old, new), b.subs(old, new))),
205                    },
206
207                    CalculusData::Limit {
208                        expression,
209                        variable,
210                        point,
211                        direction,
212                    } => CalculusData::Limit {
213                        expression: expression.subs(old, new),
214                        variable: variable.clone(),
215                        point: point.subs(old, new),
216                        direction: *direction,
217                    },
218
219                    CalculusData::Sum {
220                        expression,
221                        variable,
222                        start,
223                        end,
224                    } => CalculusData::Sum {
225                        expression: expression.subs(old, new),
226                        variable: variable.clone(),
227                        start: start.subs(old, new),
228                        end: end.subs(old, new),
229                    },
230
231                    CalculusData::Product {
232                        expression,
233                        variable,
234                        start,
235                        end,
236                    } => CalculusData::Product {
237                        expression: expression.subs(old, new),
238                        variable: variable.clone(),
239                        start: start.subs(old, new),
240                        end: end.subs(old, new),
241                    },
242                };
243
244                Expression::Calculus(Box::new(new_data))
245            }
246
247            Expression::MethodCall(data) => {
248                let new_object = data.object.subs(old, new);
249                let new_args: Vec<Expression> =
250                    data.args.iter().map(|a| a.subs(old, new)).collect();
251
252                Expression::MethodCall(Box::new(crate::core::expression::MethodCallData {
253                    object: new_object,
254                    method_name: data.method_name.clone(),
255                    args: new_args,
256                }))
257            }
258        };
259
260        result.simplify()
261    }
262
263    fn subs_multiple(&self, substitutions: &[(Expression, Expression)]) -> Expression {
264        super::rewrite::subs_multiple_impl(self, substitutions)
265    }
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prelude::*;

    #[test]
    fn test_basic_symbol_substitution() {
        let x = symbol!(x);
        let expr = Expression::symbol(x.clone());

        let result = expr.subs(&Expression::symbol(x.clone()), &Expression::integer(5));

        assert_eq!(result, Expression::integer(5));
    }

    #[test]
    fn test_substitution_in_addition() {
        let x = symbol!(x);
        let expr = Expression::add(vec![Expression::symbol(x.clone()), Expression::integer(1)]);

        let result = expr.subs(&Expression::symbol(x.clone()), &Expression::integer(5));

        assert_eq!(result, Expression::integer(6));
    }

    #[test]
    fn test_substitution_in_multiplication() {
        let x = symbol!(x);
        let expr = Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]);

        let result = expr.subs(&Expression::symbol(x.clone()), &Expression::integer(3));

        assert_eq!(result, Expression::integer(6));
    }

    #[test]
    fn test_substitution_in_power() {
        let x = symbol!(x);
        let expr = Expression::pow(Expression::symbol(x.clone()), Expression::integer(2));

        let result = expr.subs(&Expression::symbol(x.clone()), &Expression::integer(3));

        assert_eq!(result, Expression::integer(9));
    }

    #[test]
    fn test_substitution_in_function() {
        let x = symbol!(x);
        let expr = Expression::function("sin".to_string(), vec![Expression::symbol(x.clone())]);

        let result = expr.subs(&Expression::symbol(x.clone()), &Expression::integer(0));

        assert_eq!(result, Expression::integer(0));
    }

    #[test]
    fn test_nested_substitution() {
        let x = symbol!(x);
        let expr = Expression::mul(vec![
            Expression::add(vec![Expression::symbol(x.clone()), Expression::integer(1)]),
            Expression::add(vec![
                Expression::symbol(x.clone()),
                Expression::mul(vec![Expression::integer(-1), Expression::integer(1)]),
            ]),
        ]);

        let result = expr.subs(&Expression::symbol(x.clone()), &Expression::integer(2));

        assert_eq!(result, Expression::integer(3));
    }

    #[test]
    fn test_no_substitution_when_not_present() {
        let x = symbol!(x);
        let y = symbol!(y);
        let expr = Expression::symbol(y.clone());

        let result = expr.subs(&Expression::symbol(x.clone()), &Expression::integer(5));

        assert_eq!(result, Expression::symbol(y.clone()));
    }

    #[test]
    fn test_substitution_doesnt_recurse_into_replacement() {
        let x = symbol!(x);
        let y = symbol!(y);
        let expr = Expression::symbol(x.clone());

        let result = expr.subs(
            &Expression::symbol(x.clone()),
            &Expression::symbol(y.clone()),
        );

        assert_eq!(result, Expression::symbol(y.clone()));

        let result2 = result.subs(&Expression::symbol(y.clone()), &Expression::integer(5));

        assert_eq!(result2, Expression::integer(5));

        // A replacement that itself contains `old` must be inserted as-is and
        // not scanned again (otherwise x -> x + 1 would recurse or nest).
        let x_plus_one =
            Expression::add(vec![Expression::symbol(x.clone()), Expression::integer(1)]);
        let result3 = expr.subs(&Expression::symbol(x.clone()), &x_plus_one);

        assert_eq!(result3, x_plus_one);
    }

    #[test]
    fn test_substitution_preserves_position_matrices() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);
        let c = symbol!(C; matrix);

        let expr = Expression::mul(vec![
            Expression::symbol(a.clone()),
            Expression::symbol(b.clone()),
            Expression::symbol(a.clone()),
        ]);

        let result = expr.subs(
            &Expression::symbol(a.clone()),
            &Expression::symbol(c.clone()),
        );

        let expected = Expression::mul(vec![
            Expression::symbol(c.clone()),
            Expression::symbol(b.clone()),
            Expression::symbol(c.clone()),
        ]);

        assert_eq!(
            result, expected,
            "Substitution A->C in ABA must preserve positions to get CBC"
        );
    }

    #[test]
    fn test_substitution_preserves_position_operators() {
        let p = symbol!(p; operator);
        let x = symbol!(x; operator);
        let h = symbol!(H; operator);

        let expr = Expression::mul(vec![
            Expression::symbol(p.clone()),
            Expression::symbol(x.clone()),
            Expression::symbol(p.clone()),
        ]);

        let result = expr.subs(
            &Expression::symbol(p.clone()),
            &Expression::symbol(h.clone()),
        );

        let expected = Expression::mul(vec![
            Expression::symbol(h.clone()),
            Expression::symbol(x.clone()),
            Expression::symbol(h.clone()),
        ]);

        assert_eq!(
            result, expected,
            "Substitution p->H in pxp must preserve positions to get HxH"
        );
    }

    #[test]
    fn test_substitution_multiple_occurrences_different_positions() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);
        let c = symbol!(C; matrix);
        let d = symbol!(D; matrix);

        let expr = Expression::mul(vec![
            Expression::symbol(a.clone()),
            Expression::symbol(b.clone()),
            Expression::symbol(c.clone()),
            Expression::symbol(a.clone()),
        ])
        .simplify();

        let result = expr.subs(
            &Expression::symbol(a.clone()),
            &Expression::symbol(d.clone()),
        );

        let expected = Expression::mul(vec![
            Expression::symbol(d.clone()),
            Expression::symbol(b.clone()),
            Expression::symbol(c.clone()),
            Expression::symbol(d.clone()),
        ])
        .simplify();

        assert_eq!(
            result, expected,
            "Substitution A->D in ABCA must preserve all positions to get DBCD"
        );
    }

    #[test]
    fn test_substitution_quaternions_position_matters() {
        let i = symbol!(i; quaternion);
        let j = symbol!(j; quaternion);
        let k = symbol!(k; quaternion);

        let expr = Expression::mul(vec![
            Expression::symbol(i.clone()),
            Expression::symbol(j.clone()),
            Expression::symbol(i.clone()),
        ]);

        let result = expr.subs(
            &Expression::symbol(i.clone()),
            &Expression::symbol(k.clone()),
        );

        let expected = Expression::mul(vec![
            Expression::symbol(k.clone()),
            Expression::symbol(j.clone()),
            Expression::symbol(k.clone()),
        ]);

        assert_eq!(
            result, expected,
            "Substitution i->k in iji must preserve positions to get kjk"
        );
    }

    #[test]
    fn test_substitution_scalars_commutative_still_preserves_structure() {
        let x = symbol!(x);
        let y = symbol!(y);
        let z = symbol!(z);

        let expr = Expression::mul(vec![
            Expression::symbol(x.clone()),
            Expression::symbol(y.clone()),
            Expression::symbol(x.clone()),
        ]);

        let result = expr.subs(
            &Expression::symbol(x.clone()),
            &Expression::symbol(z.clone()),
        );

        let expected = Expression::mul(vec![
            Expression::symbol(z.clone()),
            Expression::symbol(y.clone()),
            Expression::symbol(z.clone()),
        ]);

        assert_eq!(
            result, expected,
            "Substitution x->z in xyx preserves structure even for commutative scalars"
        );
    }
}