mathhook_core/pattern/substitution/
core.rs

1//! Core substitution trait and single-expression substitution implementation
2
3use crate::core::Expression;
4use crate::simplify::Simplify;
5use std::sync::Arc;
6
7/// Trait for types that support substitution operations
8pub trait Substitutable {
9    /// Substitute a single expression with another
10    ///
11    /// Recursively walks the expression tree and replaces all occurrences
12    /// of `old` with `new`. The replacement is structural - it compares
13    /// expressions using PartialEq.
14    ///
15    /// # Arguments
16    ///
17    /// * `old` - The expression to replace
18    /// * `new` - The expression to substitute in
19    ///
20    /// # Examples
21    ///
22    /// ```
23    /// use mathhook_core::prelude::*;
24    /// use mathhook_core::pattern::Substitutable;
25    ///
26    /// let x = symbol!(x);
27    /// let expr = Expression::add(vec![
28    ///     Expression::symbol(x.clone()),
29    ///     Expression::integer(1)
30    /// ]);
31    ///
32    /// let result = expr.subs(&Expression::symbol(x.clone()), &Expression::integer(5));
33    /// let expected = Expression::add(vec![
34    ///     Expression::integer(5),
35    ///     Expression::integer(1)
36    /// ]);
37    ///
38    /// assert_eq!(result, expected);
39    /// ```
40    fn subs(&self, old: &Expression, new: &Expression) -> Expression;
41
42    /// Apply multiple substitutions simultaneously
43    ///
44    /// This is more efficient than chaining multiple `subs()` calls because
45    /// it performs all substitutions in a single tree traversal.
46    ///
47    /// # Arguments
48    ///
49    /// * `substitutions` - Slice of (old, new) expression pairs
50    ///
51    /// # Examples
52    ///
53    /// ```
54    /// use mathhook_core::prelude::*;
55    /// use mathhook_core::pattern::Substitutable;
56    ///
57    /// let x = symbol!(x);
58    /// let y = symbol!(y);
59    /// let expr = Expression::add(vec![
60    ///     Expression::symbol(x.clone()),
61    ///     Expression::symbol(y.clone())
62    /// ]);
63    ///
64    /// let result = expr.subs_multiple(&[
65    ///     (Expression::symbol(x.clone()), Expression::integer(1)),
66    ///     (Expression::symbol(y.clone()), Expression::integer(2)),
67    /// ]);
68    ///
69    /// let expected = Expression::add(vec![
70    ///     Expression::integer(1),
71    ///     Expression::integer(2)
72    /// ]);
73    ///
74    /// assert_eq!(result, expected);
75    /// ```
76    fn subs_multiple(&self, substitutions: &[(Expression, Expression)]) -> Expression;
77}
78
79impl Substitutable for Expression {
80    fn subs(&self, old: &Expression, new: &Expression) -> Expression {
81        if self == old {
82            return new.clone();
83        }
84
85        let result = match self {
86            Expression::Number(_) | Expression::Constant(_) => self.clone(),
87
88            Expression::Symbol(_) => self.clone(),
89
90            Expression::Add(terms) => {
91                let new_terms: Vec<Expression> = terms.iter().map(|t| t.subs(old, new)).collect();
92                Expression::Add(Arc::new(new_terms))
93            }
94
95            Expression::Mul(factors) => {
96                let new_factors: Vec<Expression> =
97                    factors.iter().map(|f| f.subs(old, new)).collect();
98                Expression::mul(new_factors)
99            }
100
101            Expression::Pow(base, exp) => {
102                let new_base = base.subs(old, new);
103                let new_exp = exp.subs(old, new);
104                Expression::Pow(Arc::new(new_base), Arc::new(new_exp))
105            }
106
107            Expression::Function { name, args } => {
108                let new_args: Vec<Expression> = args.iter().map(|a| a.subs(old, new)).collect();
109                Expression::Function {
110                    name: name.clone(),
111                    args: Arc::new(new_args),
112                }
113            }
114
115            Expression::Set(elements) => {
116                let new_elements: Vec<Expression> =
117                    elements.iter().map(|e| e.subs(old, new)).collect();
118                Expression::Set(Arc::new(new_elements))
119            }
120
121            Expression::Complex(data) => {
122                let new_real = data.real.subs(old, new);
123                let new_imag = data.imag.subs(old, new);
124                Expression::Complex(Arc::new(crate::core::expression::ComplexData {
125                    real: new_real,
126                    imag: new_imag,
127                }))
128            }
129
130            Expression::Matrix(matrix) => {
131                let (rows, cols) = matrix.dimensions();
132                let mut new_data: Vec<Vec<Expression>> = Vec::with_capacity(rows);
133
134                for i in 0..rows {
135                    let mut row: Vec<Expression> = Vec::with_capacity(cols);
136                    for j in 0..cols {
137                        let elem = matrix.get_element(i, j);
138                        row.push(elem.subs(old, new));
139                    }
140                    new_data.push(row);
141                }
142
143                Expression::Matrix(Arc::new(crate::matrices::unified::Matrix::dense(new_data)))
144            }
145
146            Expression::Relation(data) => {
147                let new_left = data.left.subs(old, new);
148                let new_right = data.right.subs(old, new);
149                Expression::Relation(Arc::new(crate::core::expression::RelationData {
150                    left: new_left,
151                    right: new_right,
152                    relation_type: data.relation_type,
153                }))
154            }
155
156            Expression::Piecewise(data) => {
157                let new_pieces: Vec<(Expression, Expression)> = data
158                    .pieces
159                    .iter()
160                    .map(|(expr, cond)| (expr.subs(old, new), cond.subs(old, new)))
161                    .collect();
162
163                let new_default = data.default.as_ref().map(|d| d.subs(old, new));
164
165                Expression::Piecewise(Arc::new(crate::core::expression::PiecewiseData {
166                    pieces: new_pieces,
167                    default: new_default,
168                }))
169            }
170
171            Expression::Interval(data) => {
172                let new_start = data.start.subs(old, new);
173                let new_end = data.end.subs(old, new);
174                Expression::Interval(Arc::new(crate::core::expression::IntervalData {
175                    start: new_start,
176                    end: new_end,
177                    start_inclusive: data.start_inclusive,
178                    end_inclusive: data.end_inclusive,
179                }))
180            }
181
182            Expression::Calculus(data) => {
183                use crate::core::expression::CalculusData;
184
185                let new_data = match data.as_ref() {
186                    CalculusData::Derivative {
187                        expression,
188                        variable,
189                        order,
190                    } => CalculusData::Derivative {
191                        expression: expression.subs(old, new),
192                        variable: variable.clone(),
193                        order: *order,
194                    },
195
196                    CalculusData::Integral {
197                        integrand,
198                        variable,
199                        bounds,
200                    } => CalculusData::Integral {
201                        integrand: integrand.subs(old, new),
202                        variable: variable.clone(),
203                        bounds: bounds
204                            .as_ref()
205                            .map(|(a, b)| (a.subs(old, new), b.subs(old, new))),
206                    },
207
208                    CalculusData::Limit {
209                        expression,
210                        variable,
211                        point,
212                        direction,
213                    } => CalculusData::Limit {
214                        expression: expression.subs(old, new),
215                        variable: variable.clone(),
216                        point: point.subs(old, new),
217                        direction: *direction,
218                    },
219
220                    CalculusData::Sum {
221                        expression,
222                        variable,
223                        start,
224                        end,
225                    } => CalculusData::Sum {
226                        expression: expression.subs(old, new),
227                        variable: variable.clone(),
228                        start: start.subs(old, new),
229                        end: end.subs(old, new),
230                    },
231
232                    CalculusData::Product {
233                        expression,
234                        variable,
235                        start,
236                        end,
237                    } => CalculusData::Product {
238                        expression: expression.subs(old, new),
239                        variable: variable.clone(),
240                        start: start.subs(old, new),
241                        end: end.subs(old, new),
242                    },
243                };
244
245                Expression::Calculus(Arc::new(new_data))
246            }
247
248            Expression::MethodCall(data) => {
249                let new_object = data.object.subs(old, new);
250                let new_args: Vec<Expression> =
251                    data.args.iter().map(|a| a.subs(old, new)).collect();
252
253                Expression::MethodCall(Arc::new(crate::core::expression::MethodCallData {
254                    object: new_object,
255                    method_name: data.method_name.clone(),
256                    args: new_args,
257                }))
258            }
259        };
260
261        result.simplify()
262    }
263
264    fn subs_multiple(&self, substitutions: &[(Expression, Expression)]) -> Expression {
265        super::rewrite::subs_multiple_impl(self, substitutions)
266    }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prelude::*;

    #[test]
    fn test_basic_symbol_substitution() {
        let x = symbol!(x);
        let expr = Expression::symbol(x.clone());

        let result = expr.subs(&Expression::symbol(x.clone()), &Expression::integer(5));

        assert_eq!(result, Expression::integer(5));
    }

    #[test]
    fn test_substitution_in_addition() {
        let x = symbol!(x);
        let expr = Expression::add(vec![Expression::symbol(x.clone()), Expression::integer(1)]);

        let result = expr.subs(&Expression::symbol(x.clone()), &Expression::integer(5));

        assert_eq!(result, Expression::integer(6));
    }

    #[test]
    fn test_substitution_in_multiplication() {
        let x = symbol!(x);
        let expr = Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]);

        let result = expr.subs(&Expression::symbol(x.clone()), &Expression::integer(3));

        assert_eq!(result, Expression::integer(6));
    }

    #[test]
    fn test_substitution_in_power() {
        let x = symbol!(x);
        let expr = Expression::pow(Expression::symbol(x.clone()), Expression::integer(2));

        let result = expr.subs(&Expression::symbol(x.clone()), &Expression::integer(3));

        assert_eq!(result, Expression::integer(9));
    }

    #[test]
    fn test_substitution_in_function() {
        let x = symbol!(x);
        let expr = Expression::function("sin", vec![Expression::symbol(x.clone())]);

        let result = expr.subs(&Expression::symbol(x.clone()), &Expression::integer(0));

        assert_eq!(result, Expression::integer(0));
    }

    #[test]
    fn test_nested_substitution() {
        let x = symbol!(x);
        let expr = Expression::mul(vec![
            Expression::add(vec![Expression::symbol(x.clone()), Expression::integer(1)]),
            Expression::add(vec![
                Expression::symbol(x.clone()),
                Expression::mul(vec![Expression::integer(-1), Expression::integer(1)]),
            ]),
        ]);

        let result = expr.subs(&Expression::symbol(x.clone()), &Expression::integer(2));

        assert_eq!(result, Expression::integer(3));
    }

    #[test]
    fn test_no_substitution_when_not_present() {
        let x = symbol!(x);
        let y = symbol!(y);
        let expr = Expression::symbol(y.clone());

        let result = expr.subs(&Expression::symbol(x.clone()), &Expression::integer(5));

        assert_eq!(result, Expression::symbol(y.clone()));
    }

    #[test]
    fn test_substitution_doesnt_recurse_into_replacement() {
        let x = symbol!(x);
        let y = symbol!(y);
        let expr = Expression::symbol(x.clone());

        let result = expr.subs(
            &Expression::symbol(x.clone()),
            &Expression::symbol(y.clone()),
        );

        assert_eq!(result, Expression::symbol(y.clone()));

        let result2 = result.subs(&Expression::symbol(y.clone()), &Expression::integer(5));

        assert_eq!(result2, Expression::integer(5));

        // The replacement `x + 1` itself contains `x`. If the replacement were searched
        // again, this substitution would never terminate; it must come back exactly as given.
        let x_plus_one =
            Expression::add(vec![Expression::symbol(x.clone()), Expression::integer(1)]);

        let result3 = expr.subs(&Expression::symbol(x.clone()), &x_plus_one);

        assert_eq!(
            result3, x_plus_one,
            "Substitution x -> x+1 must not re-substitute inside the replacement"
        );
    }

    #[test]
    fn test_substitution_preserves_position_matrices() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);
        let c = symbol!(C; matrix);

        let expr = Expression::mul(vec![
            Expression::symbol(a.clone()),
            Expression::symbol(b.clone()),
            Expression::symbol(a.clone()),
        ]);

        let result = expr.subs(
            &Expression::symbol(a.clone()),
            &Expression::symbol(c.clone()),
        );

        let expected = Expression::mul(vec![
            Expression::symbol(c.clone()),
            Expression::symbol(b.clone()),
            Expression::symbol(c.clone()),
        ]);

        assert_eq!(
            result, expected,
            "Substitution A->C in ABA must preserve positions to get CBC"
        );
    }

    #[test]
    fn test_substitution_preserves_position_operators() {
        let p = symbol!(p; operator);
        let x = symbol!(x; operator);
        let h = symbol!(H; operator);

        let expr = Expression::mul(vec![
            Expression::symbol(p.clone()),
            Expression::symbol(x.clone()),
            Expression::symbol(p.clone()),
        ]);

        let result = expr.subs(
            &Expression::symbol(p.clone()),
            &Expression::symbol(h.clone()),
        );

        let expected = Expression::mul(vec![
            Expression::symbol(h.clone()),
            Expression::symbol(x.clone()),
            Expression::symbol(h.clone()),
        ]);

        assert_eq!(
            result, expected,
            "Substitution p->H in pxp must preserve positions to get HxH"
        );
    }

    #[test]
    fn test_substitution_multiple_occurrences_different_positions() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);
        let c = symbol!(C; matrix);
        let d = symbol!(D; matrix);

        let expr = Expression::mul(vec![
            Expression::symbol(a.clone()),
            Expression::symbol(b.clone()),
            Expression::symbol(c.clone()),
            Expression::symbol(a.clone()),
        ])
        .simplify();

        let result = expr.subs(
            &Expression::symbol(a.clone()),
            &Expression::symbol(d.clone()),
        );

        let expected = Expression::mul(vec![
            Expression::symbol(d.clone()),
            Expression::symbol(b.clone()),
            Expression::symbol(c.clone()),
            Expression::symbol(d.clone()),
        ])
        .simplify();

        assert_eq!(
            result, expected,
            "Substitution A->D in ABCA must preserve all positions to get DBCD"
        );
    }

    #[test]
    fn test_substitution_quaternions_position_matters() {
        let i = symbol!(i; quaternion);
        let j = symbol!(j; quaternion);
        let k = symbol!(k; quaternion);

        let expr = Expression::mul(vec![
            Expression::symbol(i.clone()),
            Expression::symbol(j.clone()),
            Expression::symbol(i.clone()),
        ]);

        let result = expr.subs(
            &Expression::symbol(i.clone()),
            &Expression::symbol(k.clone()),
        );

        let expected = Expression::mul(vec![
            Expression::symbol(k.clone()),
            Expression::symbol(j.clone()),
            Expression::symbol(k.clone()),
        ]);

        assert_eq!(
            result, expected,
            "Substitution i->k in iji must preserve positions to get kjk"
        );
    }

    #[test]
    fn test_substitution_scalars_commutative_still_preserves_structure() {
        let x = symbol!(x);
        let y = symbol!(y);
        let z = symbol!(z);

        let expr = Expression::mul(vec![
            Expression::symbol(x.clone()),
            Expression::symbol(y.clone()),
            Expression::symbol(x.clone()),
        ]);

        let result = expr.subs(
            &Expression::symbol(x.clone()),
            &Expression::symbol(z.clone()),
        );

        let expected = Expression::mul(vec![
            Expression::symbol(z.clone()),
            Expression::symbol(y.clone()),
            Expression::symbol(z.clone()),
        ]);

        assert_eq!(
            result, expected,
            "Substitution x->z in xyx preserves structure even for commutative scalars"
        );
    }
}
518}