1use crate::core::Expression;
4use crate::simplify::Simplify;
5use std::sync::Arc;
6
7pub trait Substitutable {
9 fn subs(&self, old: &Expression, new: &Expression) -> Expression;
41
42 fn subs_multiple(&self, substitutions: &[(Expression, Expression)]) -> Expression;
77}
78
79impl Substitutable for Expression {
80 fn subs(&self, old: &Expression, new: &Expression) -> Expression {
81 if self == old {
82 return new.clone();
83 }
84
85 let result = match self {
86 Expression::Number(_) | Expression::Constant(_) => self.clone(),
87
88 Expression::Symbol(_) => self.clone(),
89
90 Expression::Add(terms) => {
91 let new_terms: Vec<Expression> = terms.iter().map(|t| t.subs(old, new)).collect();
92 Expression::Add(Arc::new(new_terms))
93 }
94
95 Expression::Mul(factors) => {
96 let new_factors: Vec<Expression> =
97 factors.iter().map(|f| f.subs(old, new)).collect();
98 Expression::mul(new_factors)
99 }
100
101 Expression::Pow(base, exp) => {
102 let new_base = base.subs(old, new);
103 let new_exp = exp.subs(old, new);
104 Expression::Pow(Arc::new(new_base), Arc::new(new_exp))
105 }
106
107 Expression::Function { name, args } => {
108 let new_args: Vec<Expression> = args.iter().map(|a| a.subs(old, new)).collect();
109 Expression::Function {
110 name: name.clone(),
111 args: Arc::new(new_args),
112 }
113 }
114
115 Expression::Set(elements) => {
116 let new_elements: Vec<Expression> =
117 elements.iter().map(|e| e.subs(old, new)).collect();
118 Expression::Set(Arc::new(new_elements))
119 }
120
121 Expression::Complex(data) => {
122 let new_real = data.real.subs(old, new);
123 let new_imag = data.imag.subs(old, new);
124 Expression::Complex(Arc::new(crate::core::expression::ComplexData {
125 real: new_real,
126 imag: new_imag,
127 }))
128 }
129
130 Expression::Matrix(matrix) => {
131 let (rows, cols) = matrix.dimensions();
132 let mut new_data: Vec<Vec<Expression>> = Vec::with_capacity(rows);
133
134 for i in 0..rows {
135 let mut row: Vec<Expression> = Vec::with_capacity(cols);
136 for j in 0..cols {
137 let elem = matrix.get_element(i, j);
138 row.push(elem.subs(old, new));
139 }
140 new_data.push(row);
141 }
142
143 Expression::Matrix(Arc::new(crate::matrices::unified::Matrix::dense(new_data)))
144 }
145
146 Expression::Relation(data) => {
147 let new_left = data.left.subs(old, new);
148 let new_right = data.right.subs(old, new);
149 Expression::Relation(Arc::new(crate::core::expression::RelationData {
150 left: new_left,
151 right: new_right,
152 relation_type: data.relation_type,
153 }))
154 }
155
156 Expression::Piecewise(data) => {
157 let new_pieces: Vec<(Expression, Expression)> = data
158 .pieces
159 .iter()
160 .map(|(expr, cond)| (expr.subs(old, new), cond.subs(old, new)))
161 .collect();
162
163 let new_default = data.default.as_ref().map(|d| d.subs(old, new));
164
165 Expression::Piecewise(Arc::new(crate::core::expression::PiecewiseData {
166 pieces: new_pieces,
167 default: new_default,
168 }))
169 }
170
171 Expression::Interval(data) => {
172 let new_start = data.start.subs(old, new);
173 let new_end = data.end.subs(old, new);
174 Expression::Interval(Arc::new(crate::core::expression::IntervalData {
175 start: new_start,
176 end: new_end,
177 start_inclusive: data.start_inclusive,
178 end_inclusive: data.end_inclusive,
179 }))
180 }
181
182 Expression::Calculus(data) => {
183 use crate::core::expression::CalculusData;
184
185 let new_data = match data.as_ref() {
186 CalculusData::Derivative {
187 expression,
188 variable,
189 order,
190 } => CalculusData::Derivative {
191 expression: expression.subs(old, new),
192 variable: variable.clone(),
193 order: *order,
194 },
195
196 CalculusData::Integral {
197 integrand,
198 variable,
199 bounds,
200 } => CalculusData::Integral {
201 integrand: integrand.subs(old, new),
202 variable: variable.clone(),
203 bounds: bounds
204 .as_ref()
205 .map(|(a, b)| (a.subs(old, new), b.subs(old, new))),
206 },
207
208 CalculusData::Limit {
209 expression,
210 variable,
211 point,
212 direction,
213 } => CalculusData::Limit {
214 expression: expression.subs(old, new),
215 variable: variable.clone(),
216 point: point.subs(old, new),
217 direction: *direction,
218 },
219
220 CalculusData::Sum {
221 expression,
222 variable,
223 start,
224 end,
225 } => CalculusData::Sum {
226 expression: expression.subs(old, new),
227 variable: variable.clone(),
228 start: start.subs(old, new),
229 end: end.subs(old, new),
230 },
231
232 CalculusData::Product {
233 expression,
234 variable,
235 start,
236 end,
237 } => CalculusData::Product {
238 expression: expression.subs(old, new),
239 variable: variable.clone(),
240 start: start.subs(old, new),
241 end: end.subs(old, new),
242 },
243 };
244
245 Expression::Calculus(Arc::new(new_data))
246 }
247
248 Expression::MethodCall(data) => {
249 let new_object = data.object.subs(old, new);
250 let new_args: Vec<Expression> =
251 data.args.iter().map(|a| a.subs(old, new)).collect();
252
253 Expression::MethodCall(Arc::new(crate::core::expression::MethodCallData {
254 object: new_object,
255 method_name: data.method_name.clone(),
256 args: new_args,
257 }))
258 }
259 };
260
261 result.simplify()
262 }
263
264 fn subs_multiple(&self, substitutions: &[(Expression, Expression)]) -> Expression {
265 super::rewrite::subs_multiple_impl(self, substitutions)
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prelude::*;

    #[test]
    fn test_basic_symbol_substitution() {
        let x = symbol!(x);
        let expr = Expression::symbol(x.clone());

        let result = expr.subs(&Expression::symbol(x.clone()), &Expression::integer(5));

        assert_eq!(result, Expression::integer(5));
    }

    #[test]
    fn test_substitution_in_addition() {
        let x = symbol!(x);
        let expr = Expression::add(vec![Expression::symbol(x.clone()), Expression::integer(1)]);

        let result = expr.subs(&Expression::symbol(x.clone()), &Expression::integer(5));

        assert_eq!(result, Expression::integer(6));
    }

    #[test]
    fn test_substitution_in_multiplication() {
        let x = symbol!(x);
        let expr = Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]);

        let result = expr.subs(&Expression::symbol(x.clone()), &Expression::integer(3));

        assert_eq!(result, Expression::integer(6));
    }

    #[test]
    fn test_substitution_in_power() {
        let x = symbol!(x);
        let expr = Expression::pow(Expression::symbol(x.clone()), Expression::integer(2));

        let result = expr.subs(&Expression::symbol(x.clone()), &Expression::integer(3));

        assert_eq!(result, Expression::integer(9));
    }

    #[test]
    fn test_substitution_in_function() {
        let x = symbol!(x);
        let expr = Expression::function("sin", vec![Expression::symbol(x.clone())]);

        let result = expr.subs(&Expression::symbol(x.clone()), &Expression::integer(0));

        assert_eq!(result, Expression::integer(0));
    }

    #[test]
    fn test_nested_substitution() {
        let x = symbol!(x);
        let expr = Expression::mul(vec![
            Expression::add(vec![Expression::symbol(x.clone()), Expression::integer(1)]),
            Expression::add(vec![
                Expression::symbol(x.clone()),
                Expression::mul(vec![Expression::integer(-1), Expression::integer(1)]),
            ]),
        ]);

        let result = expr.subs(&Expression::symbol(x.clone()), &Expression::integer(2));

        assert_eq!(result, Expression::integer(3));
    }

    #[test]
    fn test_no_substitution_when_not_present() {
        let x = symbol!(x);
        let y = symbol!(y);
        let expr = Expression::symbol(y.clone());

        let result = expr.subs(&Expression::symbol(x.clone()), &Expression::integer(5));

        assert_eq!(result, Expression::symbol(y.clone()));
    }

    /// Two independent substitutions applied one after the other: x -> y, then y -> 5.
    #[test]
    fn test_chained_substitution() {
        let x = symbol!(x);
        let y = symbol!(y);
        let expr = Expression::symbol(x.clone());

        let result = expr.subs(
            &Expression::symbol(x.clone()),
            &Expression::symbol(y.clone()),
        );

        assert_eq!(result, Expression::symbol(y.clone()));

        let result2 = result.subs(&Expression::symbol(y.clone()), &Expression::integer(5));

        assert_eq!(result2, Expression::integer(5));
    }

    /// Replacing `x` with an expression that itself contains `x` must terminate
    /// and must leave the inserted copy of `x` alone.
    #[test]
    fn test_substitution_doesnt_recurse_into_replacement() {
        let x = symbol!(x);
        let x_expr = Expression::symbol(x.clone());
        let replacement = Expression::add(vec![x_expr.clone(), Expression::integer(1)]);

        let result = x_expr.subs(&x_expr, &replacement);
        assert_eq!(result, replacement);

        let expr = Expression::mul(vec![Expression::integer(2), x_expr.clone()]);
        let shifted = expr.subs(&x_expr, &replacement);

        // 2 * (0 + 1) == 2 confirms x was replaced exactly once; a second pass
        // over the replacement would give 2 * (0 + 2) == 4.
        assert_eq!(
            shifted.subs(&x_expr, &Expression::integer(0)),
            Expression::integer(2)
        );
    }

    #[test]
    fn test_substitution_preserves_position_matrices() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);
        let c = symbol!(C; matrix);

        let expr = Expression::mul(vec![
            Expression::symbol(a.clone()),
            Expression::symbol(b.clone()),
            Expression::symbol(a.clone()),
        ]);

        let result = expr.subs(
            &Expression::symbol(a.clone()),
            &Expression::symbol(c.clone()),
        );

        let expected = Expression::mul(vec![
            Expression::symbol(c.clone()),
            Expression::symbol(b.clone()),
            Expression::symbol(c.clone()),
        ]);

        assert_eq!(
            result, expected,
            "Substitution A->C in ABA must preserve positions to get CBC"
        );
    }

    #[test]
    fn test_substitution_preserves_position_operators() {
        let p = symbol!(p; operator);
        let x = symbol!(x; operator);
        let h = symbol!(H; operator);

        let expr = Expression::mul(vec![
            Expression::symbol(p.clone()),
            Expression::symbol(x.clone()),
            Expression::symbol(p.clone()),
        ]);

        let result = expr.subs(
            &Expression::symbol(p.clone()),
            &Expression::symbol(h.clone()),
        );

        let expected = Expression::mul(vec![
            Expression::symbol(h.clone()),
            Expression::symbol(x.clone()),
            Expression::symbol(h.clone()),
        ]);

        assert_eq!(
            result, expected,
            "Substitution p->H in pxp must preserve positions to get HxH"
        );
    }

    #[test]
    fn test_substitution_multiple_occurrences_different_positions() {
        let a = symbol!(A; matrix);
        let b = symbol!(B; matrix);
        let c = symbol!(C; matrix);
        let d = symbol!(D; matrix);

        let expr = Expression::mul(vec![
            Expression::symbol(a.clone()),
            Expression::symbol(b.clone()),
            Expression::symbol(c.clone()),
            Expression::symbol(a.clone()),
        ])
        .simplify();

        let result = expr.subs(
            &Expression::symbol(a.clone()),
            &Expression::symbol(d.clone()),
        );

        let expected = Expression::mul(vec![
            Expression::symbol(d.clone()),
            Expression::symbol(b.clone()),
            Expression::symbol(c.clone()),
            Expression::symbol(d.clone()),
        ])
        .simplify();

        assert_eq!(
            result, expected,
            "Substitution A->D in ABCA must preserve all positions to get DBCD"
        );
    }

    #[test]
    fn test_substitution_quaternions_position_matters() {
        let i = symbol!(i; quaternion);
        let j = symbol!(j; quaternion);
        let k = symbol!(k; quaternion);

        let expr = Expression::mul(vec![
            Expression::symbol(i.clone()),
            Expression::symbol(j.clone()),
            Expression::symbol(i.clone()),
        ]);

        let result = expr.subs(
            &Expression::symbol(i.clone()),
            &Expression::symbol(k.clone()),
        );

        let expected = Expression::mul(vec![
            Expression::symbol(k.clone()),
            Expression::symbol(j.clone()),
            Expression::symbol(k.clone()),
        ]);

        assert_eq!(
            result, expected,
            "Substitution i->k in iji must preserve positions to get kjk"
        );
    }

    #[test]
    fn test_substitution_scalars_commutative_still_preserves_structure() {
        let x = symbol!(x);
        let y = symbol!(y);
        let z = symbol!(z);

        let expr = Expression::mul(vec![
            Expression::symbol(x.clone()),
            Expression::symbol(y.clone()),
            Expression::symbol(x.clone()),
        ]);

        let result = expr.subs(
            &Expression::symbol(x.clone()),
            &Expression::symbol(z.clone()),
        );

        let expected = Expression::mul(vec![
            Expression::symbol(z.clone()),
            Expression::symbol(y.clone()),
            Expression::symbol(z.clone()),
        ]);

        assert_eq!(
            result, expected,
            "Substitution x->z in xyx preserves structure even for commutative scalars"
        );
    }
}