mathhook_core/core/expression/methods/
analysis.rs

//! Expression analysis methods
//!
//! This module provides methods for analyzing properties of expressions:
//! commutativity analysis, variable occurrence counting, and small structural
//! accessors (`is_symbol_matching`, `as_pow`, `as_function`).
5
6use super::super::Expression;
7use crate::core::commutativity::Commutativity;
8use crate::core::Symbol;
9
10impl Expression {
11    /// Compute commutativity of this expression
12    ///
13    /// Commutativity is inferred from the symbols and operations:
14    /// - Numbers, constants: Commutative
15    /// - Symbols: Depends on symbol type (Scalar -> Commutative, Matrix/Operator/Quaternion -> Noncommutative)
16    /// - Mul: Noncommutative if ANY factor is noncommutative
17    /// - Add, Pow, Function: Depends on subexpressions
18    ///
19    /// # Examples
20    ///
21    /// Basic scalar symbols (commutative):
22    /// ```
23    /// use mathhook_core::core::symbol::Symbol;
24    /// use mathhook_core::core::expression::Expression;
25    /// use mathhook_core::core::commutativity::Commutativity;
26    ///
27    /// let x = Symbol::scalar("x");
28    /// let y = Symbol::scalar("y");
29    /// let expr = Expression::mul(vec![
30    ///     Expression::symbol(x.clone()),
31    ///     Expression::symbol(y.clone()),
32    /// ]);
33    /// assert_eq!(expr.commutativity(), Commutativity::Commutative);
34    /// ```
35    ///
36    /// Matrix symbols (noncommutative):
37    /// ```
38    /// use mathhook_core::core::symbol::Symbol;
39    /// use mathhook_core::core::expression::Expression;
40    /// use mathhook_core::core::commutativity::Commutativity;
41    ///
42    /// let a = Symbol::matrix("A");
43    /// let b = Symbol::matrix("B");
44    /// let expr = Expression::mul(vec![
45    ///     Expression::symbol(a.clone()),
46    ///     Expression::symbol(b.clone()),
47    /// ]);
48    /// assert_eq!(expr.commutativity(), Commutativity::Noncommutative);
49    /// ```
50    pub fn commutativity(&self) -> Commutativity {
51        match self {
52            Expression::Symbol(s) => s.commutativity(),
53            Expression::Number(_) => Commutativity::Commutative,
54            Expression::Constant(_) => Commutativity::Commutative,
55
56            Expression::Add(terms) => {
57                Commutativity::combine(terms.iter().map(|t| t.commutativity()))
58            }
59
60            Expression::Mul(factors) => {
61                Commutativity::combine(factors.iter().map(|f| f.commutativity()))
62            }
63
64            Expression::Pow(base, _exp) => base.commutativity(),
65
66            Expression::Function { args, .. } => {
67                Commutativity::combine(args.iter().map(|a| a.commutativity()))
68            }
69
70            Expression::Set(elements) => {
71                Commutativity::combine(elements.iter().map(|e| e.commutativity()))
72            }
73
74            Expression::Complex(data) => {
75                let real_comm = data.real.commutativity();
76                let imag_comm = data.imag.commutativity();
77                Commutativity::combine([real_comm, imag_comm])
78            }
79
80            Expression::Matrix(_) => Commutativity::Noncommutative,
81
82            Expression::Relation(data) => {
83                let left_comm = data.left.commutativity();
84                let right_comm = data.right.commutativity();
85                Commutativity::combine([left_comm, right_comm])
86            }
87
88            Expression::Piecewise(data) => {
89                let piece_comms = data
90                    .pieces
91                    .iter()
92                    .flat_map(|(expr, cond)| [expr.commutativity(), cond.commutativity()]);
93                let default_comm = data.default.as_ref().map(|e| e.commutativity()).into_iter();
94                Commutativity::combine(piece_comms.chain(default_comm))
95            }
96
97            Expression::Interval(data) => {
98                let start_comm = data.start.commutativity();
99                let end_comm = data.end.commutativity();
100                Commutativity::combine([start_comm, end_comm])
101            }
102
103            Expression::Calculus(data) => match &**data {
104                crate::core::expression::CalculusData::Derivative {
105                    expression,
106                    variable: _,
107                    order: _,
108                } => expression.commutativity(),
109                crate::core::expression::CalculusData::Integral {
110                    integrand,
111                    variable: _,
112                    bounds,
113                } => {
114                    let integrand_comm = integrand.commutativity();
115                    if let Some((lower, upper)) = bounds {
116                        Commutativity::combine([
117                            integrand_comm,
118                            lower.commutativity(),
119                            upper.commutativity(),
120                        ])
121                    } else {
122                        integrand_comm
123                    }
124                }
125                crate::core::expression::CalculusData::Limit {
126                    expression,
127                    variable: _,
128                    point,
129                    direction: _,
130                } => Commutativity::combine([expression.commutativity(), point.commutativity()]),
131                crate::core::expression::CalculusData::Sum {
132                    expression,
133                    variable: _,
134                    start,
135                    end,
136                } => Commutativity::combine([
137                    expression.commutativity(),
138                    start.commutativity(),
139                    end.commutativity(),
140                ]),
141                crate::core::expression::CalculusData::Product {
142                    expression,
143                    variable: _,
144                    start,
145                    end,
146                } => Commutativity::combine([
147                    expression.commutativity(),
148                    start.commutativity(),
149                    end.commutativity(),
150                ]),
151            },
152
153            Expression::MethodCall(data) => {
154                let object_comm = data.object.commutativity();
155                let args_comm = data.args.iter().map(|a| a.commutativity());
156                Commutativity::combine([object_comm].into_iter().chain(args_comm))
157            }
158        }
159    }
160
161    /// Count occurrences of a variable in the expression
162    ///
163    /// Recursively counts how many times a specific variable symbol appears
164    /// in the expression tree. This is useful for:
165    /// - Determining if an expression is polynomial in a variable
166    /// - Analyzing variable dependencies
167    /// - Checking if a variable appears in an equation
168    ///
169    /// # Arguments
170    ///
171    /// * `variable` - The symbol to count occurrences of
172    ///
173    /// # Returns
174    ///
175    /// The number of times the variable appears in the expression
176    ///
177    /// # Examples
178    ///
179    /// Basic counting in simple expressions:
180    /// ```
181    /// use mathhook_core::{Expression, symbol};
182    ///
183    /// let x = symbol!(x);
184    /// let expr = Expression::mul(vec![
185    ///     Expression::integer(2),
186    ///     Expression::symbol(x.clone()),
187    /// ]);
188    /// assert_eq!(expr.count_variable_occurrences(&x), 1);
189    /// ```
190    ///
191    /// Counting multiple occurrences:
192    /// ```
193    /// use mathhook_core::{Expression, symbol};
194    /// use std::sync::Arc;
195    ///
196    /// let x = symbol!(x);
197    /// // x^2 + 2*x + 1 has 2 occurrences of x (in x^2 and in 2*x)
198    /// let expr = Expression::Add(Arc::new(vec![
199    ///     Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
200    ///     Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]),
201    ///     Expression::integer(1),
202    /// ]));
203    /// assert_eq!(expr.count_variable_occurrences(&x), 2);
204    /// ```
205    ///
206    /// Counting in power expressions:
207    /// ```
208    /// use mathhook_core::{Expression, symbol};
209    ///
210    /// let x = symbol!(x);
211    /// // x^x has 2 occurrences (base and exponent)
212    /// let expr = Expression::pow(
213    ///     Expression::symbol(x.clone()),
214    ///     Expression::symbol(x.clone())
215    /// );
216    /// assert_eq!(expr.count_variable_occurrences(&x), 2);
217    /// ```
218    ///
219    /// Counting in functions:
220    /// ```
221    /// use mathhook_core::{Expression, symbol};
222    ///
223    /// let x = symbol!(x);
224    /// // sin(x)
225    /// let expr = Expression::function("sin", vec![Expression::symbol(x.clone())]);
226    /// assert_eq!(expr.count_variable_occurrences(&x), 1);
227    ///
228    /// // f(x, x, 2) has 2 occurrences
229    /// let expr2 = Expression::function("f", vec![
230    ///     Expression::symbol(x.clone()),
231    ///     Expression::symbol(x.clone()),
232    ///     Expression::integer(2),
233    /// ]);
234    /// assert_eq!(expr2.count_variable_occurrences(&x), 2);
235    /// ```
236    ///
237    /// Zero occurrences when variable is not present:
238    /// ```
239    /// use mathhook_core::{Expression, symbol};
240    ///
241    /// let x = symbol!(x);
242    /// let y = symbol!(y);
243    /// let expr = Expression::symbol(y.clone());
244    /// assert_eq!(expr.count_variable_occurrences(&x), 0);
245    /// ```
246    pub fn count_variable_occurrences(&self, variable: &Symbol) -> usize {
247        match self {
248            Expression::Symbol(s) if s == variable => 1,
249            Expression::Symbol(_) | Expression::Number(_) | Expression::Constant(_) => 0,
250
251            Expression::Add(terms) | Expression::Mul(terms) | Expression::Set(terms) => terms
252                .iter()
253                .map(|t| t.count_variable_occurrences(variable))
254                .sum(),
255
256            Expression::Pow(base, exp) => {
257                base.count_variable_occurrences(variable) + exp.count_variable_occurrences(variable)
258            }
259
260            Expression::Function { args, .. } => args
261                .iter()
262                .map(|a| a.count_variable_occurrences(variable))
263                .sum(),
264
265            Expression::Complex(data) => {
266                data.real.count_variable_occurrences(variable)
267                    + data.imag.count_variable_occurrences(variable)
268            }
269
270            Expression::Matrix(matrix) => {
271                let (rows, cols) = matrix.dimensions();
272                let mut count = 0;
273                for i in 0..rows {
274                    for j in 0..cols {
275                        count += matrix
276                            .get_element(i, j)
277                            .count_variable_occurrences(variable);
278                    }
279                }
280                count
281            }
282
283            Expression::Relation(data) => {
284                data.left.count_variable_occurrences(variable)
285                    + data.right.count_variable_occurrences(variable)
286            }
287
288            Expression::Piecewise(data) => {
289                let pieces_count: usize = data
290                    .pieces
291                    .iter()
292                    .map(|(expr, cond)| {
293                        expr.count_variable_occurrences(variable)
294                            + cond.count_variable_occurrences(variable)
295                    })
296                    .sum();
297                let default_count = data
298                    .default
299                    .as_ref()
300                    .map_or(0, |e| e.count_variable_occurrences(variable));
301                pieces_count + default_count
302            }
303
304            Expression::Interval(data) => {
305                data.start.count_variable_occurrences(variable)
306                    + data.end.count_variable_occurrences(variable)
307            }
308
309            Expression::Calculus(data) => match data.as_ref() {
310                crate::core::expression::data_types::CalculusData::Derivative {
311                    expression,
312                    variable: v,
313                    ..
314                } => {
315                    expression.count_variable_occurrences(variable)
316                        + if v == variable { 1 } else { 0 }
317                }
318                crate::core::expression::data_types::CalculusData::Integral {
319                    integrand,
320                    variable: v,
321                    bounds,
322                } => {
323                    let integrand_count = integrand.count_variable_occurrences(variable);
324                    let var_count = if v == variable { 1 } else { 0 };
325                    let bounds_count = bounds.as_ref().map_or(0, |(lower, upper)| {
326                        lower.count_variable_occurrences(variable)
327                            + upper.count_variable_occurrences(variable)
328                    });
329                    integrand_count + var_count + bounds_count
330                }
331                crate::core::expression::data_types::CalculusData::Limit {
332                    expression,
333                    variable: v,
334                    point,
335                    ..
336                } => {
337                    expression.count_variable_occurrences(variable)
338                        + if v == variable { 1 } else { 0 }
339                        + point.count_variable_occurrences(variable)
340                }
341                crate::core::expression::data_types::CalculusData::Sum {
342                    expression,
343                    variable: v,
344                    start,
345                    end,
346                }
347                | crate::core::expression::data_types::CalculusData::Product {
348                    expression,
349                    variable: v,
350                    start,
351                    end,
352                } => {
353                    expression.count_variable_occurrences(variable)
354                        + if v == variable { 1 } else { 0 }
355                        + start.count_variable_occurrences(variable)
356                        + end.count_variable_occurrences(variable)
357                }
358            },
359
360            Expression::MethodCall(data) => {
361                data.object.count_variable_occurrences(variable)
362                    + data
363                        .args
364                        .iter()
365                        .map(|a| a.count_variable_occurrences(variable))
366                        .sum::<usize>()
367            }
368        }
369    }
370
371    pub fn contains_variable(&self, symbol: &Symbol) -> bool {
372        self.count_variable_occurrences(symbol) > 0
373    }
374
375    /// Check if expression is just the variable itself
376    pub fn is_simple_variable(&self, var: &Symbol) -> bool {
377        matches!(self, Expression::Symbol(s) if s == var)
378    }
379
380    /// Check if this expression is a specific symbol
381    ///
382    /// Convenience method for pattern matching against a specific symbol.
383    /// More readable than inline matches! pattern in complex conditions.
384    ///
385    /// # Arguments
386    ///
387    /// * `symbol` - The symbol to check against
388    ///
389    /// # Returns
390    ///
391    /// True if this expression is exactly the given symbol
392    ///
393    /// # Examples
394    ///
395    /// ```
396    /// use mathhook_core::{Expression, symbol};
397    ///
398    /// let x = symbol!(x);
399    /// let y = symbol!(y);
400    /// let expr = Expression::symbol(x.clone());
401    ///
402    /// assert!(expr.is_symbol_matching(&x));
403    /// assert!(!expr.is_symbol_matching(&y));
404    /// ```
405    pub fn is_symbol_matching(&self, symbol: &Symbol) -> bool {
406        matches!(self, Expression::Symbol(s) if s == symbol)
407    }
408
409    /// Extract the base and exponent from a Pow expression
410    ///
411    /// Returns Some((base, exp)) if this is a Pow expression, None otherwise.
412    /// This is a helper method for pattern matching with the Arc-based structure.
413    #[inline]
414    pub fn as_pow(&self) -> Option<(&Expression, &Expression)> {
415        match self {
416            Expression::Pow(base, exp) => Some((base.as_ref(), exp.as_ref())),
417            _ => None,
418        }
419    }
420
421    /// Extract the name and args from a Function expression
422    ///
423    /// Returns Some((name, args)) if this is a Function expression, None otherwise.
424    /// This is a helper method for pattern matching with the Arc-based structure.
425    #[inline]
426    pub fn as_function(&self) -> Option<(&str, &[Expression])> {
427        match self {
428            Expression::Function { name, args } => Some((name.as_ref(), args.as_slice())),
429            _ => None,
430        }
431    }
432}
433
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::expression::data_types::{
        CalculusData, ComplexData, PiecewiseData, RelationData, RelationType,
    };
    use crate::expr;
    use crate::matrices::unified::Matrix;
    use crate::symbol;
    use std::sync::Arc;

    #[test]
    fn test_commutativity_scalar_multiplication() {
        let x = Symbol::scalar("x");
        let y = Symbol::scalar("y");
        let expr = Expression::mul(vec![
            Expression::symbol(x.clone()),
            Expression::symbol(y.clone()),
        ]);
        assert_eq!(expr.commutativity(), Commutativity::Commutative);
    }

    #[test]
    fn test_commutativity_matrix_multiplication() {
        let a = Symbol::matrix("A");
        let b = Symbol::matrix("B");
        let expr = Expression::mul(vec![
            Expression::symbol(a.clone()),
            Expression::symbol(b.clone()),
        ]);
        assert_eq!(expr.commutativity(), Commutativity::Noncommutative);
    }

    #[test]
    fn test_commutativity_mixed_scalar_matrix_product() {
        // A single noncommutative factor makes the whole product noncommutative.
        let x = Symbol::scalar("x");
        let a = Symbol::matrix("A");
        let expr = Expression::Mul(Arc::new(vec![
            Expression::symbol(x.clone()),
            Expression::symbol(a.clone()),
        ]));
        assert_eq!(expr.commutativity(), Commutativity::Noncommutative);
    }

    #[test]
    fn test_commutativity_number_plus_scalar() {
        let x = Symbol::scalar("x");
        let expr = Expression::Add(Arc::new(vec![
            Expression::integer(3),
            Expression::symbol(x.clone()),
        ]));
        assert_eq!(expr.commutativity(), Commutativity::Commutative);
    }

    #[test]
    fn test_commutativity_matrix_power_scalar_exponent() {
        let a = Symbol::matrix("A");
        let expr = Expression::pow(Expression::symbol(a.clone()), Expression::integer(2));
        assert_eq!(expr.commutativity(), Commutativity::Noncommutative);
    }

    #[test]
    fn test_count_in_symbol() {
        let x = symbol!(x);
        let expr = Expression::symbol(x.clone());
        assert_eq!(expr.count_variable_occurrences(&x), 1);

        let y = symbol!(y);
        assert_eq!(expr.count_variable_occurrences(&y), 0);
    }

    #[test]
    fn test_count_in_add() {
        let x = symbol!(x);
        let y = symbol!(y);
        let raw_expr = Expression::Add(Arc::new(vec![
            Expression::symbol(x.clone()),
            Expression::symbol(x.clone()),
            Expression::symbol(y.clone()),
        ]));
        assert_eq!(raw_expr.count_variable_occurrences(&x), 2);
        assert_eq!(raw_expr.count_variable_occurrences(&y), 1);
    }

    #[test]
    fn test_count_in_set() {
        let x = symbol!(x);
        let expr = Expression::Set(Arc::new(vec![
            Expression::symbol(x.clone()),
            Expression::integer(1),
            Expression::symbol(x.clone()),
        ]));
        assert_eq!(expr.count_variable_occurrences(&x), 2);
    }

    #[test]
    fn test_count_in_pow() {
        let x = symbol!(x);
        let expr = Expression::pow(Expression::symbol(x.clone()), expr!(2));
        assert_eq!(expr.count_variable_occurrences(&x), 1);

        let expr2 = Expression::pow(Expression::symbol(x.clone()), Expression::symbol(x.clone()));
        assert_eq!(expr2.count_variable_occurrences(&x), 2);
    }

    #[test]
    fn test_count_in_function() {
        let x = symbol!(x);
        let expr = Expression::function("sin", vec![Expression::symbol(x.clone())]);
        assert_eq!(expr.count_variable_occurrences(&x), 1);

        let expr2 = Expression::function(
            "f",
            vec![
                Expression::symbol(x.clone()),
                Expression::symbol(x.clone()),
                expr!(2),
            ],
        );
        assert_eq!(expr2.count_variable_occurrences(&x), 2);
    }

    #[test]
    fn test_count_in_matrix() {
        let x = symbol!(x);
        let y = symbol!(y);
        let matrix = Matrix::dense(vec![
            vec![Expression::symbol(x.clone()), Expression::symbol(y.clone())],
            vec![Expression::symbol(x.clone()), Expression::integer(1)],
        ]);
        let expr = Expression::Matrix(Arc::new(matrix));
        assert_eq!(expr.count_variable_occurrences(&x), 2);
        assert_eq!(expr.count_variable_occurrences(&y), 1);
    }

    #[test]
    fn test_count_in_complex() {
        let x = symbol!(x);
        let expr = Expression::Complex(Arc::new(ComplexData {
            real: Expression::symbol(x.clone()),
            imag: Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]),
        }));
        assert_eq!(expr.count_variable_occurrences(&x), 2);
    }

    #[test]
    fn test_count_in_relation() {
        let x = symbol!(x);
        let expr = Expression::Relation(Arc::new(RelationData {
            left: Expression::symbol(x.clone()),
            right: Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]),
            relation_type: RelationType::Equal,
        }));
        assert_eq!(expr.count_variable_occurrences(&x), 2);
    }

    #[test]
    fn test_count_in_piecewise() {
        let x = symbol!(x);
        let expr = Expression::Piecewise(Arc::new(PiecewiseData {
            pieces: vec![
                (Expression::symbol(x.clone()), Expression::symbol(x.clone())),
                (Expression::integer(0), Expression::symbol(x.clone())),
            ],
            default: Some(Expression::symbol(x.clone())),
        }));
        assert_eq!(expr.count_variable_occurrences(&x), 4);
    }

    #[test]
    fn test_count_in_piecewise_without_default() {
        let x = symbol!(x);
        let expr = Expression::Piecewise(Arc::new(PiecewiseData {
            pieces: vec![(Expression::symbol(x.clone()), Expression::integer(1))],
            default: None,
        }));
        assert_eq!(expr.count_variable_occurrences(&x), 1);
    }

    #[test]
    fn test_count_in_integral() {
        let x = symbol!(x);
        let expr = Expression::Calculus(Arc::new(CalculusData::Integral {
            integrand: Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
            variable: x.clone(),
            bounds: Some((Expression::integer(0), Expression::symbol(x.clone()))),
        }));
        assert_eq!(expr.count_variable_occurrences(&x), 3);
    }

    #[test]
    fn test_count_in_integral_without_bounds() {
        // The integration variable counts once even when the integrand is free of it.
        let x = symbol!(x);
        let y = symbol!(y);
        let expr = Expression::Calculus(Arc::new(CalculusData::Integral {
            integrand: Expression::symbol(x.clone()),
            variable: y.clone(),
            bounds: None,
        }));
        assert_eq!(expr.count_variable_occurrences(&x), 1);
        assert_eq!(expr.count_variable_occurrences(&y), 1);
    }

    #[test]
    fn test_contains_variable() {
        let x = symbol!(x);
        let y = symbol!(y);
        let expr = Expression::function("sin", vec![Expression::symbol(x.clone())]);
        assert!(expr.contains_variable(&x));
        assert!(!expr.contains_variable(&y));
    }

    #[test]
    fn test_is_simple_variable() {
        let x = symbol!(x);
        let y = symbol!(y);
        assert!(Expression::symbol(x.clone()).is_simple_variable(&x));
        assert!(!Expression::symbol(x.clone()).is_simple_variable(&y));

        let squared = Expression::pow(Expression::symbol(x.clone()), Expression::integer(2));
        assert!(!squared.is_simple_variable(&x));
        assert!(!Expression::integer(7).is_simple_variable(&x));
    }

    #[test]
    fn test_is_symbol_matching() {
        let x = symbol!(x);
        let y = symbol!(y);
        let expr_x = Expression::symbol(x.clone());
        let expr_num = Expression::integer(42);

        assert!(expr_x.is_symbol_matching(&x));
        assert!(!expr_x.is_symbol_matching(&y));
        assert!(!expr_num.is_symbol_matching(&x));
    }

    #[test]
    fn test_as_pow() {
        let x = symbol!(x);
        let pow_expr = Expression::pow(Expression::symbol(x.clone()), Expression::integer(2));

        let (base, exp) = pow_expr.as_pow().expect("should be a Pow");
        assert_eq!(*base, Expression::symbol(x.clone()));
        assert_eq!(*exp, Expression::integer(2));

        let not_pow = Expression::integer(42);
        assert!(not_pow.as_pow().is_none());
    }

    #[test]
    fn test_as_function() {
        let x = symbol!(x);
        let func = Expression::function("sin", vec![Expression::symbol(x.clone())]);

        let (name, args) = func.as_function().expect("should be a Function");
        assert_eq!(name, "sin");
        assert_eq!(args.len(), 1);
        assert_eq!(args[0], Expression::symbol(x.clone()));

        let not_func = Expression::integer(42);
        assert!(not_func.as_function().is_none());
    }
}