// mathhook_core/core/polynomial/coefficients.rs
use crate::core::{Expression, Number, Symbol};
8use std::collections::HashMap;
9
10pub fn extract_coefficient_map(expr: &Expression, var: &Symbol) -> HashMap<i64, Expression> {
36 let mut coefficients = HashMap::new();
37 extract_coefficients_recursive(expr, var, &mut coefficients);
38 coefficients
39}
40
41fn extract_coefficients_recursive(
42 expr: &Expression,
43 var: &Symbol,
44 coefficients: &mut HashMap<i64, Expression>,
45) {
46 match expr {
47 Expression::Number(_) => {
48 add_coefficient(coefficients, 0, expr.clone());
50 }
51 Expression::Symbol(s) => {
52 if s == var {
53 add_coefficient(coefficients, 1, Expression::integer(1));
55 } else {
56 add_coefficient(coefficients, 0, expr.clone());
58 }
59 }
60 Expression::Add(terms) => {
61 for term in terms.iter() {
62 extract_coefficients_recursive(term, var, coefficients);
63 }
64 }
65 Expression::Mul(factors) => {
66 let (coef, deg) = extract_term_coefficient_and_degree(factors, var);
67 add_coefficient(coefficients, deg, coef);
68 }
69 Expression::Pow(base, exp) => {
70 if let Expression::Symbol(s) = base.as_ref() {
71 if s == var {
72 if let Expression::Number(Number::Integer(n)) = exp.as_ref() {
73 add_coefficient(coefficients, *n, Expression::integer(1));
75 return;
76 }
77 }
78 }
79 add_coefficient(coefficients, 0, expr.clone());
81 }
82 _ => {
83 add_coefficient(coefficients, 0, expr.clone());
85 }
86 }
87}
88
89fn add_coefficient(coefficients: &mut HashMap<i64, Expression>, degree: i64, coef: Expression) {
90 coefficients
91 .entry(degree)
92 .and_modify(|existing| {
93 *existing = Expression::add(vec![existing.clone(), coef.clone()]);
94 })
95 .or_insert(coef);
96}
97
98fn extract_term_coefficient_and_degree(factors: &[Expression], var: &Symbol) -> (Expression, i64) {
99 let mut coefficient_factors = Vec::new();
100 let mut total_degree = 0i64;
101
102 for factor in factors.iter() {
103 match factor {
104 Expression::Symbol(s) if s == var => {
105 total_degree += 1;
106 }
107 Expression::Pow(base, exp) => {
108 if let Expression::Symbol(s) = base.as_ref() {
109 if s == var {
110 if let Expression::Number(Number::Integer(n)) = exp.as_ref() {
111 total_degree += n;
112 continue;
113 }
114 }
115 }
116 coefficient_factors.push(factor.clone());
117 }
118 _ => {
119 coefficient_factors.push(factor.clone());
120 }
121 }
122 }
123
124 let coef = if coefficient_factors.is_empty() {
125 Expression::integer(1)
126 } else if coefficient_factors.len() == 1 {
127 coefficient_factors.into_iter().next().unwrap()
128 } else {
129 Expression::mul(coefficient_factors)
130 };
131
132 (coef, total_degree)
133}
134
135pub fn coefficient_at(expr: &Expression, var: &Symbol, degree: i64) -> Expression {
161 let coeffs = extract_coefficient_map(expr, var);
162 coeffs
163 .get(°ree)
164 .cloned()
165 .unwrap_or_else(|| Expression::integer(0))
166}
167
168pub fn coefficients_list(expr: &Expression, var: &Symbol) -> Vec<(i64, Expression)> {
192 let map = extract_coefficient_map(expr, var);
193 let mut list: Vec<_> = map.into_iter().collect();
194 list.sort_by_key(|(deg, _)| *deg);
195 list
196}
197
198pub fn constant_term(expr: &Expression, var: &Symbol) -> Expression {
209 coefficient_at(expr, var, 0)
210}
211
212pub fn is_monic(expr: &Expression, var: &Symbol) -> bool {
223 use super::properties::PolynomialProperties;
224 expr.leading_coefficient(var) == Expression::integer(1)
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol;

    /// The map for `3x^2 + 2x + 1` has an entry for each of degrees 0, 1 and 2.
    #[test]
    fn test_extract_coefficient_map_simple() {
        let x = symbol!(x);
        // 3x^2 + 2x + 1
        let poly = Expression::add(vec![
            Expression::mul(vec![
                Expression::integer(3),
                Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
            ]),
            Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]),
            Expression::integer(1),
        ]);

        let coeffs = extract_coefficient_map(&poly, &x);

        assert!(coeffs.contains_key(&0));
        assert!(coeffs.contains_key(&1));
        assert!(coeffs.contains_key(&2));
    }

    /// Individual coefficients of `x^2 + 2x + 3`, including a degree that is
    /// absent (which must come back as zero).
    #[test]
    fn test_coefficient_at() {
        let x = symbol!(x);
        // x^2 + 2x + 3
        let poly = Expression::add(vec![
            Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
            Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]),
            Expression::integer(3),
        ]);

        let c0 = coefficient_at(&poly, &x, 0);
        let c1 = coefficient_at(&poly, &x, 1);
        let c2 = coefficient_at(&poly, &x, 2);
        let c3 = coefficient_at(&poly, &x, 3);

        assert_eq!(c0, Expression::integer(3));
        assert_eq!(c1, Expression::integer(2));
        assert_eq!(c2, Expression::integer(1));
        assert_eq!(c3, Expression::integer(0));
    }

    /// `x^2 + 1` lists exactly two entries, sorted by ascending degree.
    #[test]
    fn test_coefficients_list() {
        let x = symbol!(x);
        // x^2 + 1
        let poly = Expression::add(vec![
            Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
            Expression::integer(1),
        ]);

        let list = coefficients_list(&poly, &x);

        assert_eq!(list.len(), 2);
        // Constant term comes first, then the quadratic term.
        assert_eq!(list[0].0, 0);
        assert_eq!(list[1].0, 2);
    }

    /// The constant term of `x + 5` is 5.
    #[test]
    fn test_constant_term() {
        let x = symbol!(x);
        let poly = Expression::add(vec![Expression::symbol(x.clone()), Expression::integer(5)]);

        assert_eq!(constant_term(&poly, &x), Expression::integer(5));
    }

    /// `x^2 + 2x + 1` is monic; `2x^2 + x + 1` is not.
    #[test]
    fn test_is_monic() {
        let x = symbol!(x);

        // x^2 + 2x + 1
        let monic = Expression::add(vec![
            Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
            Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]),
            Expression::integer(1),
        ]);
        assert!(is_monic(&monic, &x));

        // 2x^2 + x + 1
        let not_monic = Expression::add(vec![
            Expression::mul(vec![
                Expression::integer(2),
                Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
            ]),
            Expression::symbol(x.clone()),
            Expression::integer(1),
        ]);
        assert!(!is_monic(&not_monic, &x));
    }
}