1use std::{
2 cmp::{self, Ordering},
3 collections::HashMap,
4 hash::Hash,
5};
6
7use indexmap::IndexMap;
8use num_traits::NumCast;
9
10use crate::{expr::*, ops::compare};
11
12pub fn factor<'a, E: Expr + ?Sized, F: Expr + ?Sized>(expr: &E, factors: &[&'a F]) -> Box<dyn Expr>
13where
14 &'a F: Hash + cmp::Eq,
15{
16 if let Some(eq) = expr.as_eq() {
17 return Equation::new_box(
18 factor(eq.lhs.get_ref(), factors),
19 factor(eq.rhs.get_ref(), factors),
20 );
21 }
22 let expr = expr.expand();
23 let mut factor_coeffs: HashMap<&F, Box<dyn Expr>> = HashMap::new();
24 let mut others: Vec<Box<dyn Expr>> = Vec::new();
25
26 let operands = match expr.known_expr() {
27 KnownExpr::Add(Add { operands }) => operands.iter().collect(),
28 _ => vec![&expr],
29 };
30
31 'ops: for op in operands {
34 for factor in factors {
35 if let Some(coeff) = factor_coeff_no_div(op.get_ref(), factor.get_ref()) {
36 let entry = factor_coeffs.entry(factor).or_insert(Integer::zero_box());
37 *entry += coeff;
38 continue 'ops;
39 }
40 }
41 others.push(op.clone_box());
42 }
43
44 let mut res_operands: Vec<Box<dyn Expr>> =
45 Vec::with_capacity(factor_coeffs.len() + others.len());
46
47 for fact in factors {
48 if let Some(coeff) = factor_coeffs.get(fact) {
49 if !coeff.is_zero() {
50 res_operands.push(factor(coeff.get_ref(), factors) * fact.get_ref());
51 }
52 }
53 }
54 res_operands.extend(others);
55
56 Add::new_box_v2(res_operands)
57}
58
59pub fn factor_coeff<E: Expr + ?Sized, F: Expr + ?Sized>(expr: &E, factor: &F) -> Box<dyn Expr> {
60 expr.get_ref() / factor.get_ref()
61}
62
63pub fn get_operands_exponents<E: Expr + ?Sized>(
64 expr: &E,
65) -> IndexMap<Box<dyn Expr>, Box<dyn Expr>> {
66 let mut operands_exponents: IndexMap<Box<dyn Expr>, Box<dyn Expr>> = IndexMap::new();
67
68 match expr.known_expr() {
69 KnownExpr::Pow(pow) => match pow.base.known_expr() {
70 KnownExpr::Mul(mul) => {
71 for (op, expo) in get_operands_exponents(mul) {
72 let entry = operands_exponents.entry(op).or_insert(Integer::zero_box());
73 *entry += expo * pow.exponent.clone_box();
74 }
75 }
76 _ => {
77 let entry = operands_exponents
78 .entry(pow.base.clone_box())
79 .or_insert(pow.exponent.clone_box());
80 *entry += Integer::one_box();
81 }
82 },
83 KnownExpr::Mul(mul) => {
84 for op in mul
85 .operands
86 .iter()
87 .flat_map(|op| match op.known_expr() {
89 KnownExpr::Mul(Mul { operands }) => operands.clone(),
90 KnownExpr::Pow(Pow { base, exponent })
91 if matches!(base.known_expr(), KnownExpr::Mul(Mul { .. })) =>
92 {
93 let mul = base.as_mul().unwrap();
94 mul.operands.iter().map(|op| op.pow(exponent)).collect()
95 }
96 _ => vec![op.clone_box()],
97 })
98 {
99 let (expr, exponent) = op.get_exponent();
100 let entry = operands_exponents
101 .entry(expr)
102 .or_insert(Integer::zero_box());
103 *entry += exponent;
104 }
105 }
106 _ => {
107 operands_exponents.insert(expr.clone_box(), Integer::one_box());
108 }
109 };
110
111 operands_exponents
112}
113
114pub fn factor_coeff_no_div<E: Expr + ?Sized, F: Expr + ?Sized>(
115 expr: &E,
116 factor: &F,
117) -> Option<Box<dyn Expr>> {
118 if expr.get_ref() == factor.get_ref() {
119 return Some(Integer::one_box());
120 }
121 let expr_op_expos = get_operands_exponents(expr);
122 let factor_op_expos = get_operands_exponents(factor);
123
124 for (op, factor_expo) in &factor_op_expos {
125 if let Some(expr_expo) = expr_op_expos.get(op) {
126 if expr_expo.has(factor_expo.get_ref()) {
127 continue;
128 }
129 if let Some(order) = compare(factor_expo.get_ref(), expr_expo.get_ref()) {
130 if order == Ordering::Greater {
131 return None;
132 }
133 } else {
134 return None;
135 }
136 } else {
137 return None;
138 }
139 }
140 Some(expr.get_ref() / factor.get_ref())
141}
142
143impl<N: NumCast> std::ops::Mul<N> for &dyn Expr {
144 type Output = Box<dyn Expr>;
145
146 fn mul(self, rhs: N) -> Self::Output {
147 let rhs = rhs.to_isize().unwrap();
148
149 Integer::new(rhs).get_ref() * self
150 }
151}
152
/// Creates a `Symbol` with the given name and borrows it as `&dyn Expr`.
///
/// The expansion is deliberately a bare borrow-and-cast expression: used as
/// the initializer of a `let`, the temporary `Symbol` then lives as long as
/// the binding. `Symbol` and `Expr` must be in scope at the call site.
#[macro_export]
macro_rules! symbol {
    ($symbol_name:expr) => {
        &Symbol::new($symbol_name) as &dyn Expr
    };
}
159
/// Creates a `Func` with the given name and an empty argument list, borrowed
/// as `&dyn Expr`.
///
/// The expansion is deliberately a bare borrow-and-cast expression: used as
/// the initializer of a `let`, the temporary `Func` then lives as long as the
/// binding. `Func` and `Expr` must be in scope at the call site.
#[macro_export]
macro_rules! function {
    ($func_name:expr) => {
        &Func::new($func_name, []) as &dyn Expr
    };
}
166
#[cfg(test)]
mod tests {
    use crate::symbols;

    use super::*;

    /// Factoring a symbol by itself through the method form returns it as is.
    #[test]
    fn test_factor_trivial() {
        let subject = Symbol::new("x");
        let divisor = Symbol::new("x");

        let factored = subject.factor(&[divisor.get_ref()]);

        assert_eq!(factored.srepr(), "Symbol(x)")
    }

    /// `2x / x` reduces to the integer coefficient.
    #[test]
    fn test_factor_coeff_simple() {
        let x = symbol!("x");
        let doubled = x * 2;

        let coeff = factor_coeff(&*doubled, x);

        assert_eq!(coeff.srepr(), "Integer(2)");
    }

    /// `factor_coeff` divides blindly, leaving a negative power behind when
    /// the factor is absent.
    #[test]
    fn test_factor_missing() {
        let x = symbol!("x");
        let a = symbol!("a");

        let quotient = factor_coeff(a, x);

        assert_eq!(
            quotient.srepr(),
            "Mul(Symbol(a), Pow(Symbol(x), Integer(-1)))"
        );
    }

    /// A compound factor is divided out of a larger product.
    #[test]
    fn test_factor_coeff_laplacian_u() {
        let u = function!("u");
        let laplacian = symbol!("laplacian");
        let c = symbol!("c");
        let product = c * 2 * laplacian * u;

        let coeff = factor_coeff(&*product, &*(laplacian * u));

        assert_eq!(&coeff, &(c * 2));
    }

    /// An expression is its own factor with coefficient one.
    #[test]
    fn test_factor_coeff_no_div_trivial() {
        let x = symbol!("x");

        let coeff = factor_coeff_no_div(x, x);

        assert_eq!(coeff, Some(Integer::one_box()));
    }

    /// An unrelated symbol is not a factor.
    #[test]
    fn test_factor_coeff_no_div_missing() {
        let [x, y] = symbols!("x", "y");

        let coeff = factor_coeff_no_div(y, x);

        assert_eq!(coeff, None);
    }

    /// A compound factor present in the product yields the remaining terms.
    #[test]
    fn test_factor_coeff_no_div_laplacian_u() {
        let u = function!("u");
        let laplacian = symbol!("laplacian");
        let c = symbol!("c");
        let product = c * 2 * laplacian * u;

        let coeff = factor_coeff_no_div(&*product, &*(laplacian * u));

        assert_eq!(&coeff, &(Some(c * 2)));
    }

    /// `c` cannot supply the two powers that `c^2` requires.
    #[test]
    fn test_factor_coeff_no_div_not_enough() {
        let c = symbol!("c");

        let coeff = factor_coeff_no_div(c, c.ipow(2).get_ref());

        assert_eq!(coeff, None);
    }

    /// `c^3` contains `c^2`, leaving a single `c`.
    #[test]
    fn test_factor_coeff_no_div_enough() {
        let c = symbol!("c");

        let coeff = factor_coeff_no_div(c.ipow(3).get_ref(), c.ipow(2).get_ref());

        assert_eq!(&coeff, &Some(c.clone_box()));
    }

    /// Factoring a symbol by itself through the free function returns it as is.
    #[test]
    fn test_factor_basic() {
        let x = symbol!("x");

        let factored = factor(x, &[x]);

        assert_eq!(factored.srepr(), "Symbol(x)");
    }

    /// Like terms are merged under the common factor.
    #[test]
    fn test_factor_simple() {
        let x = symbol!("x");
        let sum = x + x * 2;

        let factored = factor(&*sum, &[x]);

        assert_eq!(&factored, &(x * 3));
    }

    /// With several factors, each term goes to the first factor that fits.
    #[test]
    fn test_factor_multiple_options() {
        let [u, v, w] = symbols!("u", "v", "w");
        let polynomial = v * 3 + w + u * v * 2 + u * 4;

        let factored = factor(&*polynomial, &[u, v]);

        assert_eq!(&factored, &((v * 2 + 4) * u + v * 3 + w));
    }
}