mathhook_core/core/expression/methods/analysis.rs

1//! Expression analysis methods
2//!
3//! This module provides methods for analyzing properties of expressions,
4//! including commutativity analysis and variable occurrence counting.
5
6use super::super::Expression;
7use crate::core::commutativity::Commutativity;
8use crate::core::Symbol;
9
10impl Expression {
11    /// Compute commutativity of this expression
12    ///
13    /// Commutativity is inferred from the symbols and operations:
14    /// - Numbers, constants: Commutative
15    /// - Symbols: Depends on symbol type (Scalar → Commutative, Matrix/Operator/Quaternion → Noncommutative)
16    /// - Mul: Noncommutative if ANY factor is noncommutative
17    /// - Add, Pow, Function: Depends on subexpressions
18    ///
19    /// # Examples
20    ///
21    /// Basic scalar symbols (commutative):
22    /// ```
23    /// use mathhook_core::core::symbol::Symbol;
24    /// use mathhook_core::core::expression::Expression;
25    /// use mathhook_core::core::commutativity::Commutativity;
26    ///
27    /// let x = Symbol::scalar("x");
28    /// let y = Symbol::scalar("y");
29    /// let expr = Expression::mul(vec![
30    ///     Expression::symbol(x.clone()),
31    ///     Expression::symbol(y.clone()),
32    /// ]);
33    /// assert_eq!(expr.commutativity(), Commutativity::Commutative);
34    /// ```
35    ///
36    /// Matrix symbols (noncommutative):
37    /// ```
38    /// use mathhook_core::core::symbol::Symbol;
39    /// use mathhook_core::core::expression::Expression;
40    /// use mathhook_core::core::commutativity::Commutativity;
41    ///
42    /// let a = Symbol::matrix("A");
43    /// let b = Symbol::matrix("B");
44    /// let expr = Expression::mul(vec![
45    ///     Expression::symbol(a.clone()),
46    ///     Expression::symbol(b.clone()),
47    /// ]);
48    /// assert_eq!(expr.commutativity(), Commutativity::Noncommutative);
49    /// ```
50    pub fn commutativity(&self) -> Commutativity {
51        match self {
52            Expression::Symbol(s) => s.commutativity(),
53            Expression::Number(_) => Commutativity::Commutative,
54            Expression::Constant(_) => Commutativity::Commutative,
55
56            Expression::Add(terms) => {
57                Commutativity::combine(terms.iter().map(|t| t.commutativity()))
58            }
59
60            Expression::Mul(factors) => {
61                Commutativity::combine(factors.iter().map(|f| f.commutativity()))
62            }
63
64            Expression::Pow(base, _exp) => base.commutativity(),
65
66            Expression::Function { args, .. } => {
67                Commutativity::combine(args.iter().map(|a| a.commutativity()))
68            }
69
70            Expression::Set(elements) => {
71                Commutativity::combine(elements.iter().map(|e| e.commutativity()))
72            }
73
74            Expression::Complex(data) => {
75                let real_comm = data.real.commutativity();
76                let imag_comm = data.imag.commutativity();
77                Commutativity::combine([real_comm, imag_comm])
78            }
79
80            Expression::Matrix(_) => Commutativity::Noncommutative,
81
82            Expression::Relation(data) => {
83                let left_comm = data.left.commutativity();
84                let right_comm = data.right.commutativity();
85                Commutativity::combine([left_comm, right_comm])
86            }
87
88            Expression::Piecewise(data) => {
89                let piece_comms = data
90                    .pieces
91                    .iter()
92                    .flat_map(|(expr, cond)| [expr.commutativity(), cond.commutativity()]);
93                let default_comm = data.default.as_ref().map(|e| e.commutativity()).into_iter();
94                Commutativity::combine(piece_comms.chain(default_comm))
95            }
96
97            Expression::Interval(data) => {
98                let start_comm = data.start.commutativity();
99                let end_comm = data.end.commutativity();
100                Commutativity::combine([start_comm, end_comm])
101            }
102
103            Expression::Calculus(data) => match &**data {
104                crate::core::expression::CalculusData::Derivative {
105                    expression,
106                    variable: _,
107                    order: _,
108                } => expression.commutativity(),
109                crate::core::expression::CalculusData::Integral {
110                    integrand,
111                    variable: _,
112                    bounds,
113                } => {
114                    let integrand_comm = integrand.commutativity();
115                    if let Some((lower, upper)) = bounds {
116                        Commutativity::combine([
117                            integrand_comm,
118                            lower.commutativity(),
119                            upper.commutativity(),
120                        ])
121                    } else {
122                        integrand_comm
123                    }
124                }
125                crate::core::expression::CalculusData::Limit {
126                    expression,
127                    variable: _,
128                    point,
129                    direction: _,
130                } => Commutativity::combine([expression.commutativity(), point.commutativity()]),
131                crate::core::expression::CalculusData::Sum {
132                    expression,
133                    variable: _,
134                    start,
135                    end,
136                } => Commutativity::combine([
137                    expression.commutativity(),
138                    start.commutativity(),
139                    end.commutativity(),
140                ]),
141                crate::core::expression::CalculusData::Product {
142                    expression,
143                    variable: _,
144                    start,
145                    end,
146                } => Commutativity::combine([
147                    expression.commutativity(),
148                    start.commutativity(),
149                    end.commutativity(),
150                ]),
151            },
152
153            Expression::MethodCall(data) => {
154                let object_comm = data.object.commutativity();
155                let args_comm = data.args.iter().map(|a| a.commutativity());
156                Commutativity::combine([object_comm].into_iter().chain(args_comm))
157            }
158        }
159    }
160
161    /// Count occurrences of a variable in the expression
162    ///
163    /// Recursively counts how many times a specific variable symbol appears
164    /// in the expression tree. This is useful for:
165    /// - Determining if an expression is polynomial in a variable
166    /// - Analyzing variable dependencies
167    /// - Checking if a variable appears in an equation
168    ///
169    /// # Arguments
170    ///
171    /// * `variable` - The symbol to count occurrences of
172    ///
173    /// # Returns
174    ///
175    /// The number of times the variable appears in the expression
176    ///
177    /// # Examples
178    ///
179    /// Basic counting in simple expressions:
180    /// ```
181    /// use mathhook_core::{Expression, symbol};
182    ///
183    /// let x = symbol!(x);
184    /// let expr = Expression::mul(vec![
185    ///     Expression::integer(2),
186    ///     Expression::symbol(x.clone()),
187    /// ]);
188    /// assert_eq!(expr.count_variable_occurrences(&x), 1);
189    /// ```
190    ///
191    /// Counting multiple occurrences:
192    /// ```
193    /// use mathhook_core::{Expression, symbol};
194    ///
195    /// let x = symbol!(x);
196    /// // x^2 + 2*x + 1 has 2 occurrences of x (in x^2 and in 2*x)
197    /// let expr = Expression::Add(Box::new(vec![
198    ///     Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
199    ///     Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]),
200    ///     Expression::integer(1),
201    /// ]));
202    /// assert_eq!(expr.count_variable_occurrences(&x), 2);
203    /// ```
204    ///
205    /// Counting in power expressions:
206    /// ```
207    /// use mathhook_core::{Expression, symbol};
208    ///
209    /// let x = symbol!(x);
210    /// // x^x has 2 occurrences (base and exponent)
211    /// let expr = Expression::pow(
212    ///     Expression::symbol(x.clone()),
213    ///     Expression::symbol(x.clone())
214    /// );
215    /// assert_eq!(expr.count_variable_occurrences(&x), 2);
216    /// ```
217    ///
218    /// Counting in functions:
219    /// ```
220    /// use mathhook_core::{Expression, symbol};
221    ///
222    /// let x = symbol!(x);
223    /// // sin(x)
224    /// let expr = Expression::function("sin", vec![Expression::symbol(x.clone())]);
225    /// assert_eq!(expr.count_variable_occurrences(&x), 1);
226    ///
227    /// // f(x, x, 2) has 2 occurrences
228    /// let expr2 = Expression::function("f", vec![
229    ///     Expression::symbol(x.clone()),
230    ///     Expression::symbol(x.clone()),
231    ///     Expression::integer(2),
232    /// ]);
233    /// assert_eq!(expr2.count_variable_occurrences(&x), 2);
234    /// ```
235    ///
236    /// Zero occurrences when variable is not present:
237    /// ```
238    /// use mathhook_core::{Expression, symbol};
239    ///
240    /// let x = symbol!(x);
241    /// let y = symbol!(y);
242    /// let expr = Expression::symbol(y.clone());
243    /// assert_eq!(expr.count_variable_occurrences(&x), 0);
244    /// ```
245    pub fn count_variable_occurrences(&self, variable: &Symbol) -> usize {
246        match self {
247            Expression::Symbol(s) if s == variable => 1,
248            Expression::Symbol(_) | Expression::Number(_) | Expression::Constant(_) => 0,
249
250            Expression::Add(terms) | Expression::Mul(terms) | Expression::Set(terms) => terms
251                .iter()
252                .map(|t| t.count_variable_occurrences(variable))
253                .sum(),
254
255            Expression::Pow(base, exp) => {
256                base.count_variable_occurrences(variable) + exp.count_variable_occurrences(variable)
257            }
258
259            Expression::Function { args, .. } => args
260                .iter()
261                .map(|a| a.count_variable_occurrences(variable))
262                .sum(),
263
264            Expression::Complex(data) => {
265                data.real.count_variable_occurrences(variable)
266                    + data.imag.count_variable_occurrences(variable)
267            }
268
269            Expression::Matrix(matrix) => {
270                let (rows, cols) = matrix.dimensions();
271                let mut count = 0;
272                for i in 0..rows {
273                    for j in 0..cols {
274                        count += matrix
275                            .get_element(i, j)
276                            .count_variable_occurrences(variable);
277                    }
278                }
279                count
280            }
281
282            Expression::Relation(data) => {
283                data.left.count_variable_occurrences(variable)
284                    + data.right.count_variable_occurrences(variable)
285            }
286
287            Expression::Piecewise(data) => {
288                let pieces_count: usize = data
289                    .pieces
290                    .iter()
291                    .map(|(expr, cond)| {
292                        expr.count_variable_occurrences(variable)
293                            + cond.count_variable_occurrences(variable)
294                    })
295                    .sum();
296                let default_count = data
297                    .default
298                    .as_ref()
299                    .map_or(0, |e| e.count_variable_occurrences(variable));
300                pieces_count + default_count
301            }
302
303            Expression::Interval(data) => {
304                data.start.count_variable_occurrences(variable)
305                    + data.end.count_variable_occurrences(variable)
306            }
307
308            Expression::Calculus(data) => match data.as_ref() {
309                crate::core::expression::data_types::CalculusData::Derivative {
310                    expression,
311                    variable: v,
312                    ..
313                } => {
314                    expression.count_variable_occurrences(variable)
315                        + if v == variable { 1 } else { 0 }
316                }
317                crate::core::expression::data_types::CalculusData::Integral {
318                    integrand,
319                    variable: v,
320                    bounds,
321                } => {
322                    let integrand_count = integrand.count_variable_occurrences(variable);
323                    let var_count = if v == variable { 1 } else { 0 };
324                    let bounds_count = bounds.as_ref().map_or(0, |(lower, upper)| {
325                        lower.count_variable_occurrences(variable)
326                            + upper.count_variable_occurrences(variable)
327                    });
328                    integrand_count + var_count + bounds_count
329                }
330                crate::core::expression::data_types::CalculusData::Limit {
331                    expression,
332                    variable: v,
333                    point,
334                    ..
335                } => {
336                    expression.count_variable_occurrences(variable)
337                        + if v == variable { 1 } else { 0 }
338                        + point.count_variable_occurrences(variable)
339                }
340                crate::core::expression::data_types::CalculusData::Sum {
341                    expression,
342                    variable: v,
343                    start,
344                    end,
345                }
346                | crate::core::expression::data_types::CalculusData::Product {
347                    expression,
348                    variable: v,
349                    start,
350                    end,
351                } => {
352                    expression.count_variable_occurrences(variable)
353                        + if v == variable { 1 } else { 0 }
354                        + start.count_variable_occurrences(variable)
355                        + end.count_variable_occurrences(variable)
356                }
357            },
358
359            Expression::MethodCall(data) => {
360                data.object.count_variable_occurrences(variable)
361                    + data
362                        .args
363                        .iter()
364                        .map(|a| a.count_variable_occurrences(variable))
365                        .sum::<usize>()
366            }
367        }
368    }
369
370    pub fn contains_variable(&self, symbol: &Symbol) -> bool {
371        self.count_variable_occurrences(symbol) > 0
372    }
373
374    /// Check if expression is just the variable itself
375    pub fn is_simple_variable(&self, var: &Symbol) -> bool {
376        matches!(self, Expression::Symbol(s) if s == var)
377    }
378
379    /// Check if this expression is a specific symbol
380    ///
381    /// Convenience method for pattern matching against a specific symbol.
382    /// More readable than inline matches! pattern in complex conditions.
383    ///
384    /// # Arguments
385    ///
386    /// * `symbol` - The symbol to check against
387    ///
388    /// # Returns
389    ///
390    /// True if this expression is exactly the given symbol
391    ///
392    /// # Examples
393    ///
394    /// ```
395    /// use mathhook_core::{Expression, symbol};
396    ///
397    /// let x = symbol!(x);
398    /// let y = symbol!(y);
399    /// let expr = Expression::symbol(x.clone());
400    ///
401    /// assert!(expr.is_symbol_matching(&x));
402    /// assert!(!expr.is_symbol_matching(&y));
403    /// ```
404    pub fn is_symbol_matching(&self, symbol: &Symbol) -> bool {
405        matches!(self, Expression::Symbol(s) if s == symbol)
406    }
407}
408
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::expression::data_types::{
        CalculusData, ComplexData, PiecewiseData, RelationData, RelationType,
    };
    use crate::expr;
    use crate::matrices::unified::Matrix;
    use crate::symbol;

    #[test]
    fn test_commutativity_scalar_multiplication() {
        let x = Symbol::scalar("x");
        let y = Symbol::scalar("y");
        let expr = Expression::mul(vec![
            Expression::symbol(x.clone()),
            Expression::symbol(y.clone()),
        ]);
        assert_eq!(expr.commutativity(), Commutativity::Commutative);
    }

    #[test]
    fn test_commutativity_matrix_multiplication() {
        let a = Symbol::matrix("A");
        let b = Symbol::matrix("B");
        let expr = Expression::mul(vec![
            Expression::symbol(a.clone()),
            Expression::symbol(b.clone()),
        ]);
        assert_eq!(expr.commutativity(), Commutativity::Noncommutative);
    }

    #[test]
    fn test_commutativity_matrix_base_power() {
        // A^2 with a matrix base stays noncommutative.
        let a = Symbol::matrix("A");
        let expr = Expression::pow(Expression::symbol(a), Expression::integer(2));
        assert_eq!(expr.commutativity(), Commutativity::Noncommutative);
    }

    #[test]
    fn test_count_in_symbol() {
        let x = symbol!(x);
        let expr = Expression::symbol(x.clone());
        assert_eq!(expr.count_variable_occurrences(&x), 1);

        let y = symbol!(y);
        assert_eq!(expr.count_variable_occurrences(&y), 0);
    }

    #[test]
    fn test_count_in_add() {
        let x = symbol!(x);
        let y = symbol!(y);
        let raw_expr = Expression::Add(Box::new(vec![
            Expression::symbol(x.clone()),
            Expression::symbol(x.clone()),
            Expression::symbol(y.clone()),
        ]));
        assert_eq!(raw_expr.count_variable_occurrences(&x), 2);
        assert_eq!(raw_expr.count_variable_occurrences(&y), 1);
    }

    #[test]
    fn test_count_in_set() {
        let x = symbol!(x);
        let expr = Expression::Set(Box::new(vec![
            Expression::symbol(x.clone()),
            Expression::integer(3),
            Expression::symbol(x.clone()),
        ]));
        assert_eq!(expr.count_variable_occurrences(&x), 2);
    }

    #[test]
    fn test_count_in_pow() {
        let x = symbol!(x);
        let expr = Expression::pow(Expression::symbol(x.clone()), expr!(2));
        assert_eq!(expr.count_variable_occurrences(&x), 1);

        let expr2 = Expression::pow(Expression::symbol(x.clone()), Expression::symbol(x.clone()));
        assert_eq!(expr2.count_variable_occurrences(&x), 2);
    }

    #[test]
    fn test_count_in_function() {
        let x = symbol!(x);
        let expr = Expression::function("sin", vec![Expression::symbol(x.clone())]);
        assert_eq!(expr.count_variable_occurrences(&x), 1);

        let expr2 = Expression::function(
            "f",
            vec![
                Expression::symbol(x.clone()),
                Expression::symbol(x.clone()),
                expr!(2),
            ],
        );
        assert_eq!(expr2.count_variable_occurrences(&x), 2);
    }

    #[test]
    fn test_count_in_matrix() {
        let x = symbol!(x);
        let y = symbol!(y);
        let matrix = Matrix::dense(vec![
            vec![Expression::symbol(x.clone()), Expression::symbol(y.clone())],
            vec![Expression::symbol(x.clone()), Expression::integer(1)],
        ]);
        let expr = Expression::Matrix(Box::new(matrix));
        assert_eq!(expr.count_variable_occurrences(&x), 2);
        assert_eq!(expr.count_variable_occurrences(&y), 1);
    }

    #[test]
    fn test_count_in_complex() {
        let x = symbol!(x);
        let expr = Expression::Complex(Box::new(ComplexData {
            real: Expression::symbol(x.clone()),
            imag: Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]),
        }));
        assert_eq!(expr.count_variable_occurrences(&x), 2);
    }

    #[test]
    fn test_count_in_relation() {
        let x = symbol!(x);
        let expr = Expression::Relation(Box::new(RelationData {
            left: Expression::symbol(x.clone()),
            right: Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]),
            relation_type: RelationType::Equal,
        }));
        assert_eq!(expr.count_variable_occurrences(&x), 2);
    }

    #[test]
    fn test_count_in_piecewise() {
        let x = symbol!(x);
        let expr = Expression::Piecewise(Box::new(PiecewiseData {
            pieces: vec![
                (Expression::symbol(x.clone()), Expression::symbol(x.clone())),
                (Expression::integer(0), Expression::symbol(x.clone())),
            ],
            default: Some(Expression::symbol(x.clone())),
        }));
        assert_eq!(expr.count_variable_occurrences(&x), 4);
    }

    #[test]
    fn test_count_in_integral() {
        let x = symbol!(x);
        let expr = Expression::Calculus(Box::new(CalculusData::Integral {
            integrand: Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
            variable: x.clone(),
            bounds: Some((Expression::integer(0), Expression::symbol(x.clone()))),
        }));
        assert_eq!(expr.count_variable_occurrences(&x), 3);
    }

    #[test]
    fn test_count_in_indefinite_integral() {
        // integrand (1) + bound integration variable (1); no bounds to search.
        let x = symbol!(x);
        let expr = Expression::Calculus(Box::new(CalculusData::Integral {
            integrand: Expression::symbol(x.clone()),
            variable: x.clone(),
            bounds: None,
        }));
        assert_eq!(expr.count_variable_occurrences(&x), 2);
    }

    #[test]
    fn test_contains_variable() {
        let x = symbol!(x);
        let y = symbol!(y);
        let expr = Expression::function("sin", vec![Expression::symbol(x.clone())]);
        assert!(expr.contains_variable(&x));
        assert!(!expr.contains_variable(&y));
        assert!(!Expression::integer(7).contains_variable(&x));
    }

    #[test]
    fn test_is_simple_variable() {
        let x = symbol!(x);
        let y = symbol!(y);
        assert!(Expression::symbol(x.clone()).is_simple_variable(&x));
        assert!(!Expression::symbol(y.clone()).is_simple_variable(&x));

        let sin_x = Expression::function("sin", vec![Expression::symbol(x.clone())]);
        assert!(!sin_x.is_simple_variable(&x));
    }

    #[test]
    fn test_is_symbol_matching() {
        let x = symbol!(x);
        let y = symbol!(y);
        let expr_x = Expression::symbol(x.clone());
        let expr_num = Expression::integer(42);

        assert!(expr_x.is_symbol_matching(&x));
        assert!(!expr_x.is_symbol_matching(&y));
        assert!(!expr_num.is_symbol_matching(&x));
    }
}
559}