symrs/expr/ops/
factor.rs

1use std::{
2    cmp::{self, Ordering},
3    collections::HashMap,
4    hash::Hash,
5};
6
7use indexmap::IndexMap;
8use num_traits::NumCast;
9
10use crate::{expr::*, ops::compare};
11
12pub fn factor<'a, E: Expr + ?Sized, F: Expr + ?Sized>(expr: &E, factors: &[&'a F]) -> Box<dyn Expr>
13where
14    &'a F: Hash + cmp::Eq,
15{
16    if let Some(eq) = expr.as_eq() {
17        return Equation::new_box(
18            factor(eq.lhs.get_ref(), factors),
19            factor(eq.rhs.get_ref(), factors),
20        );
21    }
22    let expr = expr.expand();
23    let mut factor_coeffs: HashMap<&F, Box<dyn Expr>> = HashMap::new();
24    let mut others: Vec<Box<dyn Expr>> = Vec::new();
25
26    let operands = match expr.known_expr() {
27        KnownExpr::Add(Add { operands }) => operands.iter().collect(),
28        _ => vec![&expr],
29    };
30
31    // Attempt factorizing by every factor and keeping the first valid factor
32    // If none are found, add to others
33    'ops: for op in operands {
34        for factor in factors {
35            if let Some(coeff) = factor_coeff_no_div(op.get_ref(), factor.get_ref()) {
36                let entry = factor_coeffs.entry(factor).or_insert(Integer::zero_box());
37                *entry += coeff;
38                continue 'ops;
39            }
40        }
41        others.push(op.clone_box());
42    }
43
44    let mut res_operands: Vec<Box<dyn Expr>> =
45        Vec::with_capacity(factor_coeffs.len() + others.len());
46
47    for fact in factors {
48        if let Some(coeff) = factor_coeffs.get(fact) {
49            if !coeff.is_zero() {
50                res_operands.push(factor(coeff.get_ref(), factors) * fact.get_ref());
51            }
52        }
53    }
54    res_operands.extend(others);
55
56    Add::new_box_v2(res_operands)
57}
58
59pub fn factor_coeff<E: Expr + ?Sized, F: Expr + ?Sized>(expr: &E, factor: &F) -> Box<dyn Expr> {
60    expr.get_ref() / factor.get_ref()
61}
62
63pub fn get_operands_exponents<E: Expr + ?Sized>(
64    expr: &E,
65) -> IndexMap<Box<dyn Expr>, Box<dyn Expr>> {
66    let mut operands_exponents: IndexMap<Box<dyn Expr>, Box<dyn Expr>> = IndexMap::new();
67
68    match expr.known_expr() {
69        KnownExpr::Pow(pow) => match pow.base.known_expr() {
70            KnownExpr::Mul(mul) => {
71                for (op, expo) in get_operands_exponents(mul) {
72                    let entry = operands_exponents.entry(op).or_insert(Integer::zero_box());
73                    *entry += expo * pow.exponent.clone_box();
74                }
75            }
76            _ => {
77                let entry = operands_exponents
78                    .entry(pow.base.clone_box())
79                    .or_insert(pow.exponent.clone_box());
80                *entry += Integer::one_box();
81            }
82        },
83        KnownExpr::Mul(mul) => {
84            for op in mul
85                .operands
86                .iter()
87                // Split up factors fo multiplication and powers
88                .flat_map(|op| match op.known_expr() {
89                    KnownExpr::Mul(Mul { operands }) => operands.clone(),
90                    KnownExpr::Pow(Pow { base, exponent })
91                        if matches!(base.known_expr(), KnownExpr::Mul(Mul { .. })) =>
92                    {
93                        let mul = base.as_mul().unwrap();
94                        mul.operands.iter().map(|op| op.pow(exponent)).collect()
95                    }
96                    _ => vec![op.clone_box()],
97                })
98            {
99                let (expr, exponent) = op.get_exponent();
100                let entry = operands_exponents
101                    .entry(expr)
102                    .or_insert(Integer::zero_box());
103                *entry += exponent;
104            }
105        }
106        _ => {
107            operands_exponents.insert(expr.clone_box(), Integer::one_box());
108        }
109    };
110
111    operands_exponents
112}
113
114pub fn factor_coeff_no_div<E: Expr + ?Sized, F: Expr + ?Sized>(
115    expr: &E,
116    factor: &F,
117) -> Option<Box<dyn Expr>> {
118    if expr.get_ref() == factor.get_ref() {
119        return Some(Integer::one_box());
120    }
121    let expr_op_expos = get_operands_exponents(expr);
122    let factor_op_expos = get_operands_exponents(factor);
123
124    for (op, factor_expo) in &factor_op_expos {
125        if let Some(expr_expo) = expr_op_expos.get(op) {
126            if expr_expo.has(factor_expo.get_ref()) {
127                continue;
128            }
129            if let Some(order) = compare(factor_expo.get_ref(), expr_expo.get_ref()) {
130                if order == Ordering::Greater {
131                    return None;
132                }
133            } else {
134                return None;
135            }
136        } else {
137            return None;
138        }
139    }
140    Some(expr.get_ref() / factor.get_ref())
141}
142
143impl<N: NumCast> std::ops::Mul<N> for &dyn Expr {
144    type Output = Box<dyn Expr>;
145
146    fn mul(self, rhs: N) -> Self::Output {
147        let rhs = rhs.to_isize().unwrap();
148
149        Integer::new(rhs).get_ref() * self
150    }
151}
152
/// Creates a temporary `Symbol` from `$label` and borrows it as `&dyn Expr`.
///
/// The expansion names `Symbol` and `Expr` without a path, so both must be in
/// scope at the call site. The borrowed value is a temporary: bind the result
/// directly with `let` so that it lives as long as the binding.
#[macro_export]
macro_rules! symbol {
    ($label:expr) => {
        &Symbol::new($label) as &dyn Expr
    };
}
159
/// Creates a temporary argument-less `Func` from `$label` and borrows it as
/// `&dyn Expr`.
///
/// The expansion names `Func` and `Expr` without a path, so both must be in
/// scope at the call site. The borrowed value is a temporary: bind the result
/// directly with `let` so that it lives as long as the binding.
#[macro_export]
macro_rules! function {
    ($label:expr) => {
        &Func::new($label, []) as &dyn Expr
    };
}
166
#[cfg(test)]
mod tests {
    use crate::symbols;

    use super::*;

    /// The method form of `factor` leaves a lone symbol untouched.
    #[test]
    fn test_factor_trivial() {
        let subject = Symbol::new("x");
        let by = Symbol::new("x");

        assert_eq!(subject.factor(&[by.get_ref()]).srepr(), "Symbol(x)")
    }

    /// The coefficient of `x` in `2x` is `2`.
    #[test]
    fn test_factor_coeff_simple() {
        let x = symbol!("x");
        let doubled = x * 2;

        assert_eq!(factor_coeff(&*doubled, x).srepr(), "Integer(2)");
    }

    /// factor(a, x) -> a / x
    /// a = (a / x)x
    #[test]
    fn test_factor_missing() {
        let x = symbol!("x");
        let a = symbol!("a");

        let coeff = factor_coeff(a, x);

        assert_eq!(
            coeff.srepr(),
            "Mul(Symbol(a), Pow(Symbol(x), Integer(-1)))"
        );
    }

    /// A product can serve as the factor: `2c * laplacian * u` over
    /// `laplacian * u` is `2c`.
    #[test]
    fn test_factor_coeff_laplacian_u() {
        let u = function!("u");
        let laplacian = symbol!("laplacian");
        let c = symbol!("c");
        let expr = c * 2 * laplacian * u;
        let target = laplacian * u;

        let coeff = factor_coeff(&*expr, &*target);
        let expected = c * 2;

        assert_eq!(&coeff, &expected);
    }

    /// An expression divided by itself has coefficient one.
    #[test]
    fn test_factor_coeff_no_div_trivial() {
        let x = symbol!("x");

        let coeff = factor_coeff_no_div(x, x);

        assert_eq!(coeff, Some(Integer::one_box()));
    }

    /// A factor that does not occur in the expression yields `None`.
    #[test]
    fn test_factor_coeff_no_div_missing() {
        let [x, y] = symbols!("x", "y");

        let coeff = factor_coeff_no_div(y, x);

        assert_eq!(coeff, None);
    }

    /// Same as `test_factor_coeff_laplacian_u`, through the checked variant.
    #[test]
    fn test_factor_coeff_no_div_laplacian_u() {
        let u = function!("u");
        let laplacian = symbol!("laplacian");
        let c = symbol!("c");
        let expr = c * 2 * laplacian * u;
        let target = laplacian * u;

        let coeff = factor_coeff_no_div(&*expr, &*target);
        let expected = Some(c * 2);

        assert_eq!(&coeff, &expected);
    }

    /// `c` cannot be divided by `c^2`.
    #[test]
    fn test_factor_coeff_no_div_not_enough() {
        let c = symbol!("c");
        let c_squared = c.ipow(2);

        assert_eq!(factor_coeff_no_div(c, c_squared.get_ref()), None);
    }

    /// `c^3` divided by `c^2` leaves `c`.
    #[test]
    fn test_factor_coeff_no_div_enough() {
        let c = symbol!("c");
        let c_cubed = c.ipow(3);
        let c_squared = c.ipow(2);

        assert_eq!(
            &factor_coeff_no_div(c_cubed.get_ref(), c_squared.get_ref()),
            &Some(c.clone_box())
        );
    }

    /// The free function `factor` leaves a lone symbol untouched.
    #[test]
    fn test_factor_basic() {
        let x = symbol!("x");

        let factored = factor(x, &[x]);

        assert_eq!(factored.srepr(), "Symbol(x)");
    }

    /// `x + 2x` collapses to `3x`.
    #[test]
    fn test_factor_simple() {
        let x = symbol!("x");
        let sum = x + x * 2;

        let factored = factor(&*sum, &[x]);
        let expected = x * 3;

        assert_eq!(&factored, &expected);
    }

    /// With several factors, earlier ones take precedence and the coefficient
    /// is itself factored by the remaining ones.
    #[test]
    fn test_factor_multiple_options() {
        let [u, v, w] = symbols!("u", "v", "w");
        let expr = v * 3 + w + u * v * 2 + u * 4;

        let factored = factor(&*expr, &[u, v]);
        let expected = (v * 2 + 4) * u + v * 3 + w;

        assert_eq!(&factored, &expected);
    }
}