use crate::core::{Expression, Number, Symbol};
use std::collections::HashMap;
/// Builds a map from power of `var` to the coefficient expression found for
/// that power in `expr`.
///
/// The map starts empty and is filled by `extract_coefficients_recursive`;
/// powers that never occur in `expr` have no entry.
pub fn extract_coefficient_map(expr: &Expression, var: &Symbol) -> HashMap<i64, Expression> {
    let mut by_degree: HashMap<i64, Expression> = HashMap::new();
    extract_coefficients_recursive(expr, var, &mut by_degree);
    by_degree
}
/// Walks `expr` and records each term's coefficient in `coefficients` under
/// the power of `var` that the term carries.
///
/// * `var` on its own is recorded as coefficient 1 at degree 1.
/// * A sum is walked term by term, in order.
/// * A product is split by `extract_term_coefficient_and_degree`.
/// * `var` raised to an integer literal `n` is recorded as coefficient 1 at
///   degree `n`.
/// * Everything else (numbers, other symbols, any other power, and all
///   remaining expression kinds) is recorded unchanged at degree 0.
fn extract_coefficients_recursive(
    expr: &Expression,
    var: &Symbol,
    coefficients: &mut HashMap<i64, Expression>,
) {
    let (degree, coefficient) = match expr {
        Expression::Symbol(name) if name == var => (1, Expression::integer(1)),
        Expression::Add(terms) => {
            // Sums contribute nothing themselves; each term is filed separately.
            for term in terms.iter() {
                extract_coefficients_recursive(term, var, coefficients);
            }
            return;
        }
        Expression::Mul(factors) => {
            let (coefficient, degree) = extract_term_coefficient_and_degree(factors, var);
            (degree, coefficient)
        }
        Expression::Pow(base, exp) => match (base.as_ref(), exp.as_ref()) {
            (Expression::Symbol(name), Expression::Number(Number::Integer(power)))
                if name == var =>
            {
                (*power, Expression::integer(1))
            }
            // Any other power is kept whole as a degree-0 entry.
            _ => (0, expr.clone()),
        },
        // Numbers, symbols other than `var`, and every remaining expression kind.
        _ => (0, expr.clone()),
    };
    add_coefficient(coefficients, degree, coefficient);
}
fn add_coefficient(coefficients: &mut HashMap<i64, Expression>, degree: i64, coef: Expression) {
coefficients
.entry(degree)
.and_modify(|existing| {
*existing = Expression::add(vec![existing.clone(), coef.clone()]);
})
.or_insert(coef);
}
/// Splits the factors of a product into a coefficient and a power of `var`.
///
/// A factor that is `var` itself adds 1 to the degree, and a factor that is
/// `var` raised to an integer literal adds that integer. Every other factor
/// is cloned into the coefficient.
///
/// Returns `(coefficient, degree)`. The coefficient is integer 1 when no
/// other factors remain, the lone factor when exactly one remains, and
/// `Expression::mul` of the remaining factors otherwise.
fn extract_term_coefficient_and_degree(factors: &[Expression], var: &Symbol) -> (Expression, i64) {
    let mut degree = 0i64;
    let mut remaining: Vec<Expression> = Vec::new();

    for factor in factors {
        // `Some(p)` when the factor is a pure power of `var`, `None` otherwise.
        let var_power = match factor {
            Expression::Symbol(name) if name == var => Some(1),
            Expression::Pow(base, exp) => match (base.as_ref(), exp.as_ref()) {
                (Expression::Symbol(name), Expression::Number(Number::Integer(power)))
                    if name == var =>
                {
                    Some(*power)
                }
                _ => None,
            },
            _ => None,
        };

        match var_power {
            Some(power) => degree += power,
            None => remaining.push(factor.clone()),
        }
    }

    let coefficient = if remaining.len() > 1 {
        Expression::mul(remaining)
    } else {
        // Zero or one factor left: take it, or fall back to the unit coefficient.
        remaining.pop().unwrap_or_else(|| Expression::integer(1))
    };

    (coefficient, degree)
}
/// Returns the coefficient recorded for `var` raised to `degree` in `expr`.
///
/// The coefficient map is rebuilt on every call via
/// `extract_coefficient_map`. When the map has no entry for `degree`, the
/// result is `Expression::integer(0)`.
pub fn coefficient_at(expr: &Expression, var: &Symbol, degree: i64) -> Expression {
    let mut coeffs = extract_coefficient_map(expr, var);
    // The map is a local temporary, so the entry can be moved out rather than cloned.
    coeffs
        .remove(&degree)
        .unwrap_or_else(|| Expression::integer(0))
}
/// Returns the `(degree, coefficient)` pairs of `expr` with respect to `var`,
/// sorted by ascending degree.
///
/// Only degrees present in the map built by `extract_coefficient_map` appear
/// in the result.
pub fn coefficients_list(expr: &Expression, var: &Symbol) -> Vec<(i64, Expression)> {
    let mut pairs: Vec<(i64, Expression)> =
        extract_coefficient_map(expr, var).into_iter().collect();
    // Degrees are map keys and therefore unique, so an unstable sort gives the same order.
    pairs.sort_unstable_by_key(|pair| pair.0);
    pairs
}
/// Returns the degree-0 coefficient of `expr` with respect to `var`.
///
/// This is a thin wrapper around `coefficient_at` with the degree fixed to 0.
pub fn constant_term(expr: &Expression, var: &Symbol) -> Expression {
    const CONSTANT_DEGREE: i64 = 0;
    coefficient_at(expr, var, CONSTANT_DEGREE)
}
/// Reports whether the leading coefficient of `expr` in `var`, as returned by
/// `PolynomialProperties::leading_coefficient`, equals `Expression::integer(1)`.
pub fn is_monic(expr: &Expression, var: &Symbol) -> bool {
    // The trait must be in scope for the `leading_coefficient` method call.
    use super::properties::PolynomialProperties;

    let leading = expr.leading_coefficient(var);
    leading == Expression::integer(1)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbol;

    /// The map for 3*x^2 + 2*x + 1 has an entry for each of degrees 0, 1 and 2.
    #[test]
    fn test_extract_coefficient_map_simple() {
        let x = symbol!(x);
        let poly = Expression::add(vec![
            Expression::mul(vec![
                Expression::integer(3),
                Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
            ]),
            Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]),
            Expression::integer(1),
        ]);
        let coeffs = extract_coefficient_map(&poly, &x);
        assert!(coeffs.contains_key(&0));
        assert!(coeffs.contains_key(&1));
        assert!(coeffs.contains_key(&2));
    }

    /// For x^2 + 2*x + 3 each present degree yields its coefficient, and a
    /// missing degree yields integer 0.
    #[test]
    fn test_coefficient_at() {
        let x = symbol!(x);
        let poly = Expression::add(vec![
            Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
            Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]),
            Expression::integer(3),
        ]);
        let c0 = coefficient_at(&poly, &x, 0);
        let c1 = coefficient_at(&poly, &x, 1);
        let c2 = coefficient_at(&poly, &x, 2);
        let c3 = coefficient_at(&poly, &x, 3);
        assert_eq!(c0, Expression::integer(3));
        assert_eq!(c1, Expression::integer(2));
        assert_eq!(c2, Expression::integer(1));
        assert_eq!(c3, Expression::integer(0));
    }

    /// For x^2 + 1 the list holds exactly degrees 0 and 2, in ascending order.
    #[test]
    fn test_coefficients_list() {
        let x = symbol!(x);
        let poly = Expression::add(vec![
            Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
            Expression::integer(1),
        ]);
        let list = coefficients_list(&poly, &x);
        assert_eq!(list.len(), 2);
        assert_eq!(list[0].0, 0);
        assert_eq!(list[1].0, 2);
    }

    /// The constant term of x + 5 is 5.
    #[test]
    fn test_constant_term() {
        let x = symbol!(x);
        let poly = Expression::add(vec![Expression::symbol(x.clone()), Expression::integer(5)]);
        assert_eq!(constant_term(&poly, &x), Expression::integer(5));
    }

    /// x^2 + 2*x + 1 is monic; 2*x^2 + x + 1 is not.
    #[test]
    fn test_is_monic() {
        let x = symbol!(x);
        let monic = Expression::add(vec![
            Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
            Expression::mul(vec![Expression::integer(2), Expression::symbol(x.clone())]),
            Expression::integer(1),
        ]);
        assert!(is_monic(&monic, &x));
        let not_monic = Expression::add(vec![
            Expression::mul(vec![
                Expression::integer(2),
                Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
            ]),
            Expression::symbol(x.clone()),
            Expression::integer(1),
        ]);
        assert!(!is_monic(&not_monic, &x));
    }
}