mathhook_core/core/polynomial/
coefficients.rs

1//! Polynomial Coefficient Utilities
2//!
3//! Functions for extracting and manipulating polynomial coefficients.
4//! Provides utilities for working with polynomial coefficient lists
5//! and extracting coefficients at specific degrees.
6
7use crate::core::{Expression, Number, Symbol};
8use std::collections::HashMap;
9
10/// Extract all coefficients of a polynomial as a map from degree to coefficient
11///
12/// Returns a HashMap where keys are degrees and values are coefficients.
13///
14/// # Arguments
15///
16/// * `expr` - The polynomial expression
17/// * `var` - The variable to extract coefficients for
18///
19/// # Returns
20///
21/// A HashMap mapping degree (i64) to coefficient (Expression)
22///
23/// # Examples
24///
25/// ```rust
26/// use mathhook_core::core::polynomial::extract_coefficient_map;
27/// use mathhook_core::{expr, symbol};
28///
29/// let x = symbol!(x);
30/// let poly = expr!((3 * (x ^ 2)) + (2 * x) + 1);
31///
32/// let coeffs = extract_coefficient_map(&poly, &x);
33/// // coeffs[0] = 1, coeffs[1] = 2, coeffs[2] = 3
34/// ```
35pub fn extract_coefficient_map(expr: &Expression, var: &Symbol) -> HashMap<i64, Expression> {
36    let mut coefficients = HashMap::new();
37    extract_coefficients_recursive(expr, var, &mut coefficients);
38    coefficients
39}
40
41fn extract_coefficients_recursive(
42    expr: &Expression,
43    var: &Symbol,
44    coefficients: &mut HashMap<i64, Expression>,
45) {
46    match expr {
47        Expression::Number(_) => {
48            // Constant term has degree 0
49            add_coefficient(coefficients, 0, expr.clone());
50        }
51        Expression::Symbol(s) => {
52            if s == var {
53                // x has degree 1, coefficient 1
54                add_coefficient(coefficients, 1, Expression::integer(1));
55            } else {
56                // Other symbol is a constant (degree 0)
57                add_coefficient(coefficients, 0, expr.clone());
58            }
59        }
60        Expression::Add(terms) => {
61            for term in terms.iter() {
62                extract_coefficients_recursive(term, var, coefficients);
63            }
64        }
65        Expression::Mul(factors) => {
66            let (coef, deg) = extract_term_coefficient_and_degree(factors, var);
67            add_coefficient(coefficients, deg, coef);
68        }
69        Expression::Pow(base, exp) => {
70            if let Expression::Symbol(s) = base.as_ref() {
71                if s == var {
72                    if let Expression::Number(Number::Integer(n)) = exp.as_ref() {
73                        // x^n has degree n, coefficient 1
74                        add_coefficient(coefficients, *n, Expression::integer(1));
75                        return;
76                    }
77                }
78            }
79            // Non-variable power is a constant
80            add_coefficient(coefficients, 0, expr.clone());
81        }
82        _ => {
83            // Treat other expressions as constants
84            add_coefficient(coefficients, 0, expr.clone());
85        }
86    }
87}
88
89fn add_coefficient(coefficients: &mut HashMap<i64, Expression>, degree: i64, coef: Expression) {
90    coefficients
91        .entry(degree)
92        .and_modify(|existing| {
93            *existing = Expression::add(vec![existing.clone(), coef.clone()]);
94        })
95        .or_insert(coef);
96}
97
98fn extract_term_coefficient_and_degree(factors: &[Expression], var: &Symbol) -> (Expression, i64) {
99    let mut coefficient_factors = Vec::new();
100    let mut total_degree = 0i64;
101
102    for factor in factors.iter() {
103        match factor {
104            Expression::Symbol(s) if s == var => {
105                total_degree += 1;
106            }
107            Expression::Pow(base, exp) => {
108                if let Expression::Symbol(s) = base.as_ref() {
109                    if s == var {
110                        if let Expression::Number(Number::Integer(n)) = exp.as_ref() {
111                            total_degree += n;
112                            continue;
113                        }
114                    }
115                }
116                coefficient_factors.push(factor.clone());
117            }
118            _ => {
119                coefficient_factors.push(factor.clone());
120            }
121        }
122    }
123
124    let coef = if coefficient_factors.is_empty() {
125        Expression::integer(1)
126    } else if coefficient_factors.len() == 1 {
127        coefficient_factors.into_iter().next().unwrap()
128    } else {
129        Expression::mul(coefficient_factors)
130    };
131
132    (coef, total_degree)
133}
134
135/// Get the coefficient at a specific degree
136///
137/// # Arguments
138///
139/// * `expr` - The polynomial expression
140/// * `var` - The variable
141/// * `degree` - The degree to extract coefficient for
142///
143/// # Returns
144///
145/// The coefficient at the specified degree, or 0 if not present
146///
147/// # Examples
148///
149/// ```rust
150/// use mathhook_core::core::polynomial::coefficient_at;
151/// use mathhook_core::{expr, symbol};
152///
153/// let x = symbol!(x);
154/// let poly = expr!((3 * (x ^ 2)) + (2 * x) + 1);
155///
156/// let c2 = coefficient_at(&poly, &x, 2);  // Returns 3
157/// let c1 = coefficient_at(&poly, &x, 1);  // Returns 2
158/// let c0 = coefficient_at(&poly, &x, 0);  // Returns 1
159/// ```
160pub fn coefficient_at(expr: &Expression, var: &Symbol, degree: i64) -> Expression {
161    let coeffs = extract_coefficient_map(expr, var);
162    coeffs
163        .get(&degree)
164        .cloned()
165        .unwrap_or_else(|| Expression::integer(0))
166}
167
168/// Get all coefficients as a vector ordered by degree (ascending)
169///
170/// # Arguments
171///
172/// * `expr` - The polynomial expression
173/// * `var` - The variable
174///
175/// # Returns
176///
177/// A vector of (degree, coefficient) pairs ordered by degree
178///
179/// # Examples
180///
181/// ```rust
182/// use mathhook_core::core::polynomial::coefficients_list;
183/// use mathhook_core::{expr, symbol};
184///
185/// let x = symbol!(x);
186/// let poly = expr!((3 * (x ^ 2)) + (2 * x) + 1);
187///
188/// let coeffs = coefficients_list(&poly, &x);
189/// // Returns [(0, 1), (1, 2), (2, 3)]
190/// ```
191pub fn coefficients_list(expr: &Expression, var: &Symbol) -> Vec<(i64, Expression)> {
192    let map = extract_coefficient_map(expr, var);
193    let mut list: Vec<_> = map.into_iter().collect();
194    list.sort_by_key(|(deg, _)| *deg);
195    list
196}
197
198/// Extract the constant term (coefficient of degree 0)
199///
200/// # Arguments
201///
202/// * `expr` - The polynomial expression
203/// * `var` - The variable
204///
205/// # Returns
206///
207/// The constant term
208pub fn constant_term(expr: &Expression, var: &Symbol) -> Expression {
209    coefficient_at(expr, var, 0)
210}
211
212/// Check if polynomial is monic (leading coefficient is 1)
213///
214/// # Arguments
215///
216/// * `expr` - The polynomial expression
217/// * `var` - The variable
218///
219/// # Returns
220///
221/// True if the polynomial is monic
222pub fn is_monic(expr: &Expression, var: &Symbol) -> bool {
223    use super::properties::PolynomialProperties;
224    expr.leading_coefficient(var) == Expression::integer(1)
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol;

    #[test]
    fn test_extract_coefficient_map_simple() {
        let x = symbol!(x);
        // 3x^2 + 2x + 1
        let poly = Expression::add(vec![
            Expression::mul(vec![
                Expression::integer(3),
                Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
            ]),
            Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]),
            Expression::integer(1),
        ]);

        let coeffs = extract_coefficient_map(&poly, &x);

        // Check the coefficient values, not just that each degree is present.
        assert_eq!(coeffs.len(), 3);
        assert_eq!(coeffs[&0], Expression::integer(1));
        assert_eq!(coeffs[&1], Expression::integer(2));
        assert_eq!(coeffs[&2], Expression::integer(3));
    }

    #[test]
    fn test_coefficient_at() {
        let x = symbol!(x);
        // x^2 + 2x + 3
        let poly = Expression::add(vec![
            Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
            Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]),
            Expression::integer(3),
        ]);

        let c0 = coefficient_at(&poly, &x, 0);
        let c1 = coefficient_at(&poly, &x, 1);
        let c2 = coefficient_at(&poly, &x, 2);
        let c3 = coefficient_at(&poly, &x, 3);

        assert_eq!(c0, Expression::integer(3));
        assert_eq!(c1, Expression::integer(2));
        assert_eq!(c2, Expression::integer(1));
        assert_eq!(c3, Expression::integer(0));
    }

    #[test]
    fn test_coefficients_list() {
        let x = symbol!(x);
        // x^2 + 1
        let poly = Expression::add(vec![
            Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
            Expression::integer(1),
        ]);

        let list = coefficients_list(&poly, &x);

        assert_eq!(list.len(), 2);
        assert_eq!(list[0], (0, Expression::integer(1))); // degree 0
        assert_eq!(list[1], (2, Expression::integer(1))); // degree 2
    }

    #[test]
    fn test_constant_term() {
        let x = symbol!(x);
        let poly = Expression::add(vec![Expression::symbol(x.clone()), Expression::integer(5)]);

        assert_eq!(constant_term(&poly, &x), Expression::integer(5));
    }

    #[test]
    fn test_other_symbol_is_constant_term() {
        let x = symbol!(x);
        let y = symbol!(y);
        // x + y: y does not involve x, so it lands at degree 0
        let poly = Expression::add(vec![
            Expression::symbol(x.clone()),
            Expression::symbol(y.clone()),
        ]);

        assert_eq!(constant_term(&poly, &x), Expression::symbol(y));
        assert_eq!(coefficient_at(&poly, &x, 1), Expression::integer(1));
    }

    #[test]
    fn test_is_monic() {
        let x = symbol!(x);

        // x^2 + 2x + 1 is monic
        let monic = Expression::add(vec![
            Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
            Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]),
            Expression::integer(1),
        ]);
        assert!(is_monic(&monic, &x));

        // 2x^2 + x + 1 is NOT monic
        let not_monic = Expression::add(vec![
            Expression::mul(vec![
                Expression::integer(2),
                Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
            ]),
            Expression::symbol(x.clone()),
            Expression::integer(1),
        ]);
        assert!(!is_monic(&not_monic, &x));
    }
}