1use crate::core::Expression;
4use crate::simplify::Simplify;
5
6pub trait Substitutable {
8 fn subs(&self, old: &Expression, new: &Expression) -> Expression;
40
41 fn subs_multiple(&self, substitutions: &[(Expression, Expression)]) -> Expression;
76}
77
78impl Substitutable for Expression {
79 fn subs(&self, old: &Expression, new: &Expression) -> Expression {
80 if self == old {
81 return new.clone();
82 }
83
84 let result = match self {
85 Expression::Number(_) | Expression::Constant(_) => self.clone(),
86
87 Expression::Symbol(_) => self.clone(),
88
89 Expression::Add(terms) => {
90 let new_terms: Vec<Expression> = terms.iter().map(|t| t.subs(old, new)).collect();
91 Expression::Add(Box::new(new_terms))
92 }
93
94 Expression::Mul(factors) => {
95 let new_factors: Vec<Expression> =
96 factors.iter().map(|f| f.subs(old, new)).collect();
97 Expression::mul(new_factors)
98 }
99
100 Expression::Pow(base, exp) => {
101 let new_base = base.subs(old, new);
102 let new_exp = exp.subs(old, new);
103 Expression::Pow(Box::new(new_base), Box::new(new_exp))
104 }
105
106 Expression::Function { name, args } => {
107 let new_args: Vec<Expression> = args.iter().map(|a| a.subs(old, new)).collect();
108 Expression::Function {
109 name: name.clone(),
110 args: Box::new(new_args),
111 }
112 }
113
114 Expression::Set(elements) => {
115 let new_elements: Vec<Expression> =
116 elements.iter().map(|e| e.subs(old, new)).collect();
117 Expression::Set(Box::new(new_elements))
118 }
119
120 Expression::Complex(data) => {
121 let new_real = data.real.subs(old, new);
122 let new_imag = data.imag.subs(old, new);
123 Expression::Complex(Box::new(crate::core::expression::ComplexData {
124 real: new_real,
125 imag: new_imag,
126 }))
127 }
128
129 Expression::Matrix(matrix) => {
130 let (rows, cols) = matrix.dimensions();
131 let mut new_data: Vec<Vec<Expression>> = Vec::with_capacity(rows);
132
133 for i in 0..rows {
134 let mut row: Vec<Expression> = Vec::with_capacity(cols);
135 for j in 0..cols {
136 let elem = matrix.get_element(i, j);
137 row.push(elem.subs(old, new));
138 }
139 new_data.push(row);
140 }
141
142 Expression::Matrix(Box::new(crate::matrices::unified::Matrix::dense(new_data)))
143 }
144
145 Expression::Relation(data) => {
146 let new_left = data.left.subs(old, new);
147 let new_right = data.right.subs(old, new);
148 Expression::Relation(Box::new(crate::core::expression::RelationData {
149 left: new_left,
150 right: new_right,
151 relation_type: data.relation_type,
152 }))
153 }
154
155 Expression::Piecewise(data) => {
156 let new_pieces: Vec<(Expression, Expression)> = data
157 .pieces
158 .iter()
159 .map(|(expr, cond)| (expr.subs(old, new), cond.subs(old, new)))
160 .collect();
161
162 let new_default = data.default.as_ref().map(|d| d.subs(old, new));
163
164 Expression::Piecewise(Box::new(crate::core::expression::PiecewiseData {
165 pieces: new_pieces,
166 default: new_default,
167 }))
168 }
169
170 Expression::Interval(data) => {
171 let new_start = data.start.subs(old, new);
172 let new_end = data.end.subs(old, new);
173 Expression::Interval(Box::new(crate::core::expression::IntervalData {
174 start: new_start,
175 end: new_end,
176 start_inclusive: data.start_inclusive,
177 end_inclusive: data.end_inclusive,
178 }))
179 }
180
181 Expression::Calculus(data) => {
182 use crate::core::expression::CalculusData;
183
184 let new_data = match data.as_ref() {
185 CalculusData::Derivative {
186 expression,
187 variable,
188 order,
189 } => CalculusData::Derivative {
190 expression: expression.subs(old, new),
191 variable: variable.clone(),
192 order: *order,
193 },
194
195 CalculusData::Integral {
196 integrand,
197 variable,
198 bounds,
199 } => CalculusData::Integral {
200 integrand: integrand.subs(old, new),
201 variable: variable.clone(),
202 bounds: bounds
203 .as_ref()
204 .map(|(a, b)| (a.subs(old, new), b.subs(old, new))),
205 },
206
207 CalculusData::Limit {
208 expression,
209 variable,
210 point,
211 direction,
212 } => CalculusData::Limit {
213 expression: expression.subs(old, new),
214 variable: variable.clone(),
215 point: point.subs(old, new),
216 direction: *direction,
217 },
218
219 CalculusData::Sum {
220 expression,
221 variable,
222 start,
223 end,
224 } => CalculusData::Sum {
225 expression: expression.subs(old, new),
226 variable: variable.clone(),
227 start: start.subs(old, new),
228 end: end.subs(old, new),
229 },
230
231 CalculusData::Product {
232 expression,
233 variable,
234 start,
235 end,
236 } => CalculusData::Product {
237 expression: expression.subs(old, new),
238 variable: variable.clone(),
239 start: start.subs(old, new),
240 end: end.subs(old, new),
241 },
242 };
243
244 Expression::Calculus(Box::new(new_data))
245 }
246
247 Expression::MethodCall(data) => {
248 let new_object = data.object.subs(old, new);
249 let new_args: Vec<Expression> =
250 data.args.iter().map(|a| a.subs(old, new)).collect();
251
252 Expression::MethodCall(Box::new(crate::core::expression::MethodCallData {
253 object: new_object,
254 method_name: data.method_name.clone(),
255 args: new_args,
256 }))
257 }
258 };
259
260 result.simplify()
261 }
262
263 fn subs_multiple(&self, substitutions: &[(Expression, Expression)]) -> Expression {
264 super::rewrite::subs_multiple_impl(self, substitutions)
265 }
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prelude::*;

    // Basic substitution into each node kind.

    #[test]
    fn test_basic_symbol_substitution() {
        let x = symbol!(x);
        let expr = Expression::symbol(x.clone());

        let result = expr.subs(&Expression::symbol(x.clone()), &Expression::integer(5));

        assert_eq!(result, Expression::integer(5));
    }

    #[test]
    fn test_substitution_in_addition() {
        let x = symbol!(x);
        let expr = Expression::add(vec![Expression::symbol(x.clone()), Expression::integer(1)]);

        let result = expr.subs(&Expression::symbol(x.clone()), &Expression::integer(5));

        assert_eq!(result, Expression::integer(6));
    }

    #[test]
    fn test_substitution_in_multiplication() {
        let x = symbol!(x);
        let expr = Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]);

        let result = expr.subs(&Expression::symbol(x.clone()), &Expression::integer(3));

        assert_eq!(result, Expression::integer(6));
    }

    #[test]
    fn test_substitution_in_power() {
        let x = symbol!(x);
        let expr = Expression::pow(Expression::symbol(x.clone()), Expression::integer(2));

        let result = expr.subs(&Expression::symbol(x.clone()), &Expression::integer(3));

        assert_eq!(result, Expression::integer(9));
    }

    #[test]
    fn test_substitution_in_function() {
        let x = symbol!(x);
        let expr = Expression::function("sin".to_string(), vec![Expression::symbol(x.clone())]);

        let result = expr.subs(&Expression::symbol(x.clone()), &Expression::integer(0));

        assert_eq!(result, Expression::integer(0));
    }

    #[test]
    fn test_nested_substitution() {
        // (x + 1) * (x - 1) at x = 2 is 3 * 1 = 3.
        let x = symbol!(x);
        let expr = Expression::mul(vec![
            Expression::add(vec![Expression::symbol(x.clone()), Expression::integer(1)]),
            Expression::add(vec![
                Expression::symbol(x.clone()),
                Expression::mul(vec![Expression::integer(-1), Expression::integer(1)]),
            ]),
        ]);

        let result = expr.subs(&Expression::symbol(x.clone()), &Expression::integer(2));

        assert_eq!(result, Expression::integer(3));
    }

    #[test]
    fn test_no_substitution_when_not_present() {
        let x = symbol!(x);
        let y = symbol!(y);
        let expr = Expression::symbol(y.clone());

        let result = expr.subs(&Expression::symbol(x.clone()), &Expression::integer(5));

        assert_eq!(result, Expression::symbol(y.clone()));
    }

    #[test]
    fn test_substitution_doesnt_recurse_into_replacement() {
        let x = symbol!(x);
        let y = symbol!(y);
        let x_expr = Expression::symbol(x.clone());

        // x -> y followed by y -> 5: two independent substitutions chain normally.
        let result = x_expr.subs(&x_expr, &Expression::symbol(y.clone()));

        assert_eq!(result, Expression::symbol(y.clone()));

        let result2 = result.subs(&Expression::symbol(y.clone()), &Expression::integer(5));

        assert_eq!(result2, Expression::integer(5));

        // x -> x + 1: the replacement contains `old` itself. It must be inserted
        // as-is, not rewritten again (which would give x + 2, x + 3, ... or loop).
        let x_plus_one = Expression::add(vec![x_expr.clone(), Expression::integer(1)]);
        let shifted = x_expr.subs(&x_expr, &x_plus_one);

        assert_eq!(shifted, x_plus_one);
    }

    // Non-commutative products: substitution must keep every factor in place.

    #[test]
    fn test_substitution_preserves_position_matrices() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);
        let c = symbol!(C; matrix);

        let expr = Expression::mul(vec![
            Expression::symbol(a.clone()),
            Expression::symbol(b.clone()),
            Expression::symbol(a.clone()),
        ]);

        let result = expr.subs(
            &Expression::symbol(a.clone()),
            &Expression::symbol(c.clone()),
        );

        let expected = Expression::mul(vec![
            Expression::symbol(c.clone()),
            Expression::symbol(b.clone()),
            Expression::symbol(c.clone()),
        ]);

        assert_eq!(
            result, expected,
            "Substitution A->C in ABA must preserve positions to get CBC"
        );
    }

    #[test]
    fn test_substitution_preserves_position_operators() {
        let p = symbol!(p; operator);
        let x = symbol!(x; operator);
        let h = symbol!(H; operator);

        let expr = Expression::mul(vec![
            Expression::symbol(p.clone()),
            Expression::symbol(x.clone()),
            Expression::symbol(p.clone()),
        ]);

        let result = expr.subs(
            &Expression::symbol(p.clone()),
            &Expression::symbol(h.clone()),
        );

        let expected = Expression::mul(vec![
            Expression::symbol(h.clone()),
            Expression::symbol(x.clone()),
            Expression::symbol(h.clone()),
        ]);

        assert_eq!(
            result, expected,
            "Substitution p->H in pxp must preserve positions to get HxH"
        );
    }

    #[test]
    fn test_substitution_multiple_occurrences_different_positions() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);
        let c = symbol!(C; matrix);
        let d = symbol!(D; matrix);

        let expr = Expression::mul(vec![
            Expression::symbol(a.clone()),
            Expression::symbol(b.clone()),
            Expression::symbol(c.clone()),
            Expression::symbol(a.clone()),
        ])
        .simplify();

        let result = expr.subs(
            &Expression::symbol(a.clone()),
            &Expression::symbol(d.clone()),
        );

        let expected = Expression::mul(vec![
            Expression::symbol(d.clone()),
            Expression::symbol(b.clone()),
            Expression::symbol(c.clone()),
            Expression::symbol(d.clone()),
        ])
        .simplify();

        assert_eq!(
            result, expected,
            "Substitution A->D in ABCA must preserve all positions to get DBCD"
        );
    }

    #[test]
    fn test_substitution_quaternions_position_matters() {
        let i = symbol!(i; quaternion);
        let j = symbol!(j; quaternion);
        let k = symbol!(k; quaternion);

        let expr = Expression::mul(vec![
            Expression::symbol(i.clone()),
            Expression::symbol(j.clone()),
            Expression::symbol(i.clone()),
        ]);

        let result = expr.subs(
            &Expression::symbol(i.clone()),
            &Expression::symbol(k.clone()),
        );

        let expected = Expression::mul(vec![
            Expression::symbol(k.clone()),
            Expression::symbol(j.clone()),
            Expression::symbol(k.clone()),
        ]);

        assert_eq!(
            result, expected,
            "Substitution i->k in iji must preserve positions to get kjk"
        );
    }

    #[test]
    fn test_substitution_scalars_commutative_still_preserves_structure() {
        let x = symbol!(x);
        let y = symbol!(y);
        let z = symbol!(z);

        let expr = Expression::mul(vec![
            Expression::symbol(x.clone()),
            Expression::symbol(y.clone()),
            Expression::symbol(x.clone()),
        ]);

        let result = expr.subs(
            &Expression::symbol(x.clone()),
            &Expression::symbol(z.clone()),
        );

        let expected = Expression::mul(vec![
            Expression::symbol(z.clone()),
            Expression::symbol(y.clone()),
            Expression::symbol(z.clone()),
        ]);

        assert_eq!(
            result, expected,
            "Substitution x->z in xyx preserves structure even for commutative scalars"
        );
    }
}