// mathhook_core/core/expression/methods/analysis.rs
use super::super::Expression;
use crate::core::commutativity::Commutativity;
use crate::core::expression::data_types::CalculusData;
use crate::core::Symbol;

10impl Expression {
11 pub fn commutativity(&self) -> Commutativity {
51 match self {
52 Expression::Symbol(s) => s.commutativity(),
53 Expression::Number(_) => Commutativity::Commutative,
54 Expression::Constant(_) => Commutativity::Commutative,
55
56 Expression::Add(terms) => {
57 Commutativity::combine(terms.iter().map(|t| t.commutativity()))
58 }
59
60 Expression::Mul(factors) => {
61 Commutativity::combine(factors.iter().map(|f| f.commutativity()))
62 }
63
64 Expression::Pow(base, _exp) => base.commutativity(),
65
66 Expression::Function { args, .. } => {
67 Commutativity::combine(args.iter().map(|a| a.commutativity()))
68 }
69
70 Expression::Set(elements) => {
71 Commutativity::combine(elements.iter().map(|e| e.commutativity()))
72 }
73
74 Expression::Complex(data) => {
75 let real_comm = data.real.commutativity();
76 let imag_comm = data.imag.commutativity();
77 Commutativity::combine([real_comm, imag_comm])
78 }
79
80 Expression::Matrix(_) => Commutativity::Noncommutative,
81
82 Expression::Relation(data) => {
83 let left_comm = data.left.commutativity();
84 let right_comm = data.right.commutativity();
85 Commutativity::combine([left_comm, right_comm])
86 }
87
88 Expression::Piecewise(data) => {
89 let piece_comms = data
90 .pieces
91 .iter()
92 .flat_map(|(expr, cond)| [expr.commutativity(), cond.commutativity()]);
93 let default_comm = data.default.as_ref().map(|e| e.commutativity()).into_iter();
94 Commutativity::combine(piece_comms.chain(default_comm))
95 }
96
97 Expression::Interval(data) => {
98 let start_comm = data.start.commutativity();
99 let end_comm = data.end.commutativity();
100 Commutativity::combine([start_comm, end_comm])
101 }
102
103 Expression::Calculus(data) => match &**data {
104 crate::core::expression::CalculusData::Derivative {
105 expression,
106 variable: _,
107 order: _,
108 } => expression.commutativity(),
109 crate::core::expression::CalculusData::Integral {
110 integrand,
111 variable: _,
112 bounds,
113 } => {
114 let integrand_comm = integrand.commutativity();
115 if let Some((lower, upper)) = bounds {
116 Commutativity::combine([
117 integrand_comm,
118 lower.commutativity(),
119 upper.commutativity(),
120 ])
121 } else {
122 integrand_comm
123 }
124 }
125 crate::core::expression::CalculusData::Limit {
126 expression,
127 variable: _,
128 point,
129 direction: _,
130 } => Commutativity::combine([expression.commutativity(), point.commutativity()]),
131 crate::core::expression::CalculusData::Sum {
132 expression,
133 variable: _,
134 start,
135 end,
136 } => Commutativity::combine([
137 expression.commutativity(),
138 start.commutativity(),
139 end.commutativity(),
140 ]),
141 crate::core::expression::CalculusData::Product {
142 expression,
143 variable: _,
144 start,
145 end,
146 } => Commutativity::combine([
147 expression.commutativity(),
148 start.commutativity(),
149 end.commutativity(),
150 ]),
151 },
152
153 Expression::MethodCall(data) => {
154 let object_comm = data.object.commutativity();
155 let args_comm = data.args.iter().map(|a| a.commutativity());
156 Commutativity::combine([object_comm].into_iter().chain(args_comm))
157 }
158 }
159 }
160
161 pub fn count_variable_occurrences(&self, variable: &Symbol) -> usize {
247 match self {
248 Expression::Symbol(s) if s == variable => 1,
249 Expression::Symbol(_) | Expression::Number(_) | Expression::Constant(_) => 0,
250
251 Expression::Add(terms) | Expression::Mul(terms) | Expression::Set(terms) => terms
252 .iter()
253 .map(|t| t.count_variable_occurrences(variable))
254 .sum(),
255
256 Expression::Pow(base, exp) => {
257 base.count_variable_occurrences(variable) + exp.count_variable_occurrences(variable)
258 }
259
260 Expression::Function { args, .. } => args
261 .iter()
262 .map(|a| a.count_variable_occurrences(variable))
263 .sum(),
264
265 Expression::Complex(data) => {
266 data.real.count_variable_occurrences(variable)
267 + data.imag.count_variable_occurrences(variable)
268 }
269
270 Expression::Matrix(matrix) => {
271 let (rows, cols) = matrix.dimensions();
272 let mut count = 0;
273 for i in 0..rows {
274 for j in 0..cols {
275 count += matrix
276 .get_element(i, j)
277 .count_variable_occurrences(variable);
278 }
279 }
280 count
281 }
282
283 Expression::Relation(data) => {
284 data.left.count_variable_occurrences(variable)
285 + data.right.count_variable_occurrences(variable)
286 }
287
288 Expression::Piecewise(data) => {
289 let pieces_count: usize = data
290 .pieces
291 .iter()
292 .map(|(expr, cond)| {
293 expr.count_variable_occurrences(variable)
294 + cond.count_variable_occurrences(variable)
295 })
296 .sum();
297 let default_count = data
298 .default
299 .as_ref()
300 .map_or(0, |e| e.count_variable_occurrences(variable));
301 pieces_count + default_count
302 }
303
304 Expression::Interval(data) => {
305 data.start.count_variable_occurrences(variable)
306 + data.end.count_variable_occurrences(variable)
307 }
308
309 Expression::Calculus(data) => match data.as_ref() {
310 crate::core::expression::data_types::CalculusData::Derivative {
311 expression,
312 variable: v,
313 ..
314 } => {
315 expression.count_variable_occurrences(variable)
316 + if v == variable { 1 } else { 0 }
317 }
318 crate::core::expression::data_types::CalculusData::Integral {
319 integrand,
320 variable: v,
321 bounds,
322 } => {
323 let integrand_count = integrand.count_variable_occurrences(variable);
324 let var_count = if v == variable { 1 } else { 0 };
325 let bounds_count = bounds.as_ref().map_or(0, |(lower, upper)| {
326 lower.count_variable_occurrences(variable)
327 + upper.count_variable_occurrences(variable)
328 });
329 integrand_count + var_count + bounds_count
330 }
331 crate::core::expression::data_types::CalculusData::Limit {
332 expression,
333 variable: v,
334 point,
335 ..
336 } => {
337 expression.count_variable_occurrences(variable)
338 + if v == variable { 1 } else { 0 }
339 + point.count_variable_occurrences(variable)
340 }
341 crate::core::expression::data_types::CalculusData::Sum {
342 expression,
343 variable: v,
344 start,
345 end,
346 }
347 | crate::core::expression::data_types::CalculusData::Product {
348 expression,
349 variable: v,
350 start,
351 end,
352 } => {
353 expression.count_variable_occurrences(variable)
354 + if v == variable { 1 } else { 0 }
355 + start.count_variable_occurrences(variable)
356 + end.count_variable_occurrences(variable)
357 }
358 },
359
360 Expression::MethodCall(data) => {
361 data.object.count_variable_occurrences(variable)
362 + data
363 .args
364 .iter()
365 .map(|a| a.count_variable_occurrences(variable))
366 .sum::<usize>()
367 }
368 }
369 }
370
371 pub fn contains_variable(&self, symbol: &Symbol) -> bool {
372 self.count_variable_occurrences(symbol) > 0
373 }
374
375 pub fn is_simple_variable(&self, var: &Symbol) -> bool {
377 matches!(self, Expression::Symbol(s) if s == var)
378 }
379
380 pub fn is_symbol_matching(&self, symbol: &Symbol) -> bool {
406 matches!(self, Expression::Symbol(s) if s == symbol)
407 }
408
409 #[inline]
414 pub fn as_pow(&self) -> Option<(&Expression, &Expression)> {
415 match self {
416 Expression::Pow(base, exp) => Some((base.as_ref(), exp.as_ref())),
417 _ => None,
418 }
419 }
420
421 #[inline]
426 pub fn as_function(&self) -> Option<(&str, &[Expression])> {
427 match self {
428 Expression::Function { name, args } => Some((name.as_ref(), args.as_slice())),
429 _ => None,
430 }
431 }
432}
433
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::expression::data_types::{
        CalculusData, ComplexData, PiecewiseData, RelationData, RelationType,
    };
    use crate::expr;
    use crate::matrices::unified::Matrix;
    use crate::symbol;
    use std::sync::Arc;

    #[test]
    fn test_commutativity_scalar_multiplication() {
        let x = Symbol::scalar("x");
        let y = Symbol::scalar("y");
        let expr = Expression::mul(vec![
            Expression::symbol(x.clone()),
            Expression::symbol(y.clone()),
        ]);
        assert_eq!(expr.commutativity(), Commutativity::Commutative);
    }

    #[test]
    fn test_commutativity_matrix_multiplication() {
        let a = Symbol::matrix("A");
        let b = Symbol::matrix("B");
        let expr = Expression::mul(vec![
            Expression::symbol(a.clone()),
            Expression::symbol(b.clone()),
        ]);
        assert_eq!(expr.commutativity(), Commutativity::Noncommutative);
    }

    /// Numbers commute; a concrete matrix literal never does.
    #[test]
    fn test_commutativity_leaves() {
        assert_eq!(
            Expression::integer(3).commutativity(),
            Commutativity::Commutative
        );

        let literal = Expression::Matrix(Arc::new(Matrix::dense(vec![vec![
            Expression::integer(1),
        ]])));
        assert_eq!(literal.commutativity(), Commutativity::Noncommutative);
    }

    #[test]
    fn test_count_in_symbol() {
        let x = symbol!(x);
        let expr = Expression::symbol(x.clone());
        assert_eq!(expr.count_variable_occurrences(&x), 1);

        let y = symbol!(y);
        assert_eq!(expr.count_variable_occurrences(&y), 0);
    }

    #[test]
    fn test_count_in_add() {
        let x = symbol!(x);
        let y = symbol!(y);
        let raw_expr = Expression::Add(Arc::new(vec![
            Expression::symbol(x.clone()),
            Expression::symbol(x.clone()),
            Expression::symbol(y.clone()),
        ]));
        assert_eq!(raw_expr.count_variable_occurrences(&x), 2);
        assert_eq!(raw_expr.count_variable_occurrences(&y), 1);
    }

    #[test]
    fn test_count_in_pow() {
        let x = symbol!(x);
        let expr = Expression::pow(Expression::symbol(x.clone()), expr!(2));
        assert_eq!(expr.count_variable_occurrences(&x), 1);

        let expr2 = Expression::pow(Expression::symbol(x.clone()), Expression::symbol(x.clone()));
        assert_eq!(expr2.count_variable_occurrences(&x), 2);
    }

    #[test]
    fn test_count_in_function() {
        let x = symbol!(x);
        let expr = Expression::function("sin", vec![Expression::symbol(x.clone())]);
        assert_eq!(expr.count_variable_occurrences(&x), 1);

        let expr2 = Expression::function(
            "f",
            vec![
                Expression::symbol(x.clone()),
                Expression::symbol(x.clone()),
                expr!(2),
            ],
        );
        assert_eq!(expr2.count_variable_occurrences(&x), 2);
    }

    #[test]
    fn test_count_in_matrix() {
        let x = symbol!(x);
        let y = symbol!(y);
        let matrix = Matrix::dense(vec![
            vec![Expression::symbol(x.clone()), Expression::symbol(y.clone())],
            vec![Expression::symbol(x.clone()), Expression::integer(1)],
        ]);
        let expr = Expression::Matrix(Arc::new(matrix));
        assert_eq!(expr.count_variable_occurrences(&x), 2);
        assert_eq!(expr.count_variable_occurrences(&y), 1);
    }

    #[test]
    fn test_count_in_complex() {
        let x = symbol!(x);
        let expr = Expression::Complex(Arc::new(ComplexData {
            real: Expression::symbol(x.clone()),
            imag: Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]),
        }));
        assert_eq!(expr.count_variable_occurrences(&x), 2);
    }

    #[test]
    fn test_count_in_relation() {
        let x = symbol!(x);
        let expr = Expression::Relation(Arc::new(RelationData {
            left: Expression::symbol(x.clone()),
            right: Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]),
            relation_type: RelationType::Equal,
        }));
        assert_eq!(expr.count_variable_occurrences(&x), 2);
    }

    #[test]
    fn test_count_in_piecewise() {
        let x = symbol!(x);
        let expr = Expression::Piecewise(Arc::new(PiecewiseData {
            pieces: vec![
                (Expression::symbol(x.clone()), Expression::symbol(x.clone())),
                (Expression::integer(0), Expression::symbol(x.clone())),
            ],
            default: Some(Expression::symbol(x.clone())),
        }));
        assert_eq!(expr.count_variable_occurrences(&x), 4);
    }

    #[test]
    fn test_count_in_integral() {
        let x = symbol!(x);
        let expr = Expression::Calculus(Arc::new(CalculusData::Integral {
            integrand: Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
            variable: x.clone(),
            bounds: Some((Expression::integer(0), Expression::symbol(x.clone()))),
        }));
        assert_eq!(expr.count_variable_occurrences(&x), 3);
    }

    /// An indefinite integral counts the integrand plus the integration variable,
    /// and nothing for a symbol that appears in neither.
    #[test]
    fn test_count_in_integral_without_bounds() {
        let x = symbol!(x);
        let y = symbol!(y);
        let expr = Expression::Calculus(Arc::new(CalculusData::Integral {
            integrand: Expression::symbol(x.clone()),
            variable: x.clone(),
            bounds: None,
        }));
        assert_eq!(expr.count_variable_occurrences(&x), 2);
        assert_eq!(expr.count_variable_occurrences(&y), 0);
        assert!(expr.contains_variable(&x));
        assert!(!expr.contains_variable(&y));
    }

    #[test]
    fn test_contains_variable() {
        let x = symbol!(x);
        let y = symbol!(y);
        let expr = Expression::pow(Expression::symbol(x.clone()), expr!(2));
        assert!(expr.contains_variable(&x));
        assert!(!expr.contains_variable(&y));
        assert!(!Expression::integer(7).contains_variable(&x));
    }

    #[test]
    fn test_is_simple_variable() {
        let x = symbol!(x);
        let y = symbol!(y);
        assert!(Expression::symbol(x.clone()).is_simple_variable(&x));
        assert!(!Expression::symbol(x.clone()).is_simple_variable(&y));
        assert!(!Expression::integer(42).is_simple_variable(&x));

        // A compound expression containing `x` is not *simply* `x`.
        let squared = Expression::pow(Expression::symbol(x.clone()), Expression::integer(2));
        assert!(!squared.is_simple_variable(&x));
    }

    #[test]
    fn test_is_symbol_matching() {
        let x = symbol!(x);
        let y = symbol!(y);
        let expr_x = Expression::symbol(x.clone());
        let expr_num = Expression::integer(42);

        assert!(expr_x.is_symbol_matching(&x));
        assert!(!expr_x.is_symbol_matching(&y));
        assert!(!expr_num.is_symbol_matching(&x));
    }

    #[test]
    fn test_as_pow() {
        let x = symbol!(x);
        let pow_expr = Expression::pow(Expression::symbol(x.clone()), Expression::integer(2));

        let (base, exp) = pow_expr.as_pow().expect("should be a Pow");
        assert_eq!(*base, Expression::symbol(x.clone()));
        assert_eq!(*exp, Expression::integer(2));

        let not_pow = Expression::integer(42);
        assert!(not_pow.as_pow().is_none());
    }

    #[test]
    fn test_as_function() {
        let x = symbol!(x);
        let func = Expression::function("sin", vec![Expression::symbol(x.clone())]);

        let (name, args) = func.as_function().expect("should be a Function");
        assert_eq!(name, "sin");
        assert_eq!(args.len(), 1);
        assert_eq!(args[0], Expression::symbol(x.clone()));

        let not_func = Expression::integer(42);
        assert!(not_func.as_function().is_none());
    }
}