mathhook_core/core/polynomial/poly/conversion.rs

1use super::IntPoly;
2use crate::core::{Expression, Number, Symbol};
3
4impl IntPoly {
5    /// Try to convert an Expression to IntPoly
6    ///
7    /// Returns Some(IntPoly) if the expression is a univariate polynomial
8    /// with integer coefficients. Returns None otherwise.
9    ///
10    /// # Example
11    /// ```rust
12    /// use mathhook_core::{symbol, expr};
13    /// use mathhook_core::core::polynomial::poly::IntPoly;
14    ///
15    /// let x = symbol!(x);
16    /// let poly = expr!((x^2) + (2*x) + 1);
17    /// if let Some(int_poly) = IntPoly::try_from_expression(&poly, &x) {
18    ///     let deriv = int_poly.derivative();
19    /// }
20    /// ```
21    pub fn try_from_expression(expr: &Expression, var: &Symbol) -> Option<Self> {
22        let mut coeffs = std::collections::HashMap::new();
23
24        if !extract_int_coefficients(expr, var, &mut coeffs) {
25            return None;
26        }
27
28        if coeffs.is_empty() {
29            return Some(Self::zero());
30        }
31
32        let max_deg = *coeffs.keys().max()?;
33        if max_deg > 1000 {
34            return None;
35        }
36
37        let mut coeff_vec = vec![0i64; max_deg as usize + 1];
38        for (deg, coeff) in coeffs {
39            if deg >= 0 {
40                coeff_vec[deg as usize] = coeff;
41            }
42        }
43
44        Some(Self::from_coeffs(coeff_vec))
45    }
46
47    /// Convert IntPoly back to Expression
48    ///
49    /// # Example
50    /// ```rust
51    /// use mathhook_core::symbol;
52    /// use mathhook_core::core::polynomial::poly::IntPoly;
53    ///
54    /// let x = symbol!(x);
55    /// let p = IntPoly::from_coeffs(vec![1, 2, 3]);
56    /// let expr = p.to_expression(&x);
57    /// ```
58    pub fn to_expression(&self, var: &Symbol) -> Expression {
59        if self.is_zero() {
60            return Expression::integer(0);
61        }
62
63        let mut terms = Vec::new();
64
65        for (i, &c) in self.coeffs.iter().enumerate() {
66            if c == 0 {
67                continue;
68            }
69
70            let term = match i {
71                0 => Expression::integer(c),
72                1 if c == 1 => Expression::symbol(var.clone()),
73                1 if c == -1 => Expression::mul(vec![
74                    Expression::integer(-1),
75                    Expression::symbol(var.clone()),
76                ]),
77                1 => Expression::mul(vec![
78                    Expression::integer(c),
79                    Expression::symbol(var.clone()),
80                ]),
81                _ if c == 1 => Expression::pow(
82                    Expression::symbol(var.clone()),
83                    Expression::integer(i as i64),
84                ),
85                _ if c == -1 => Expression::mul(vec![
86                    Expression::integer(-1),
87                    Expression::pow(
88                        Expression::symbol(var.clone()),
89                        Expression::integer(i as i64),
90                    ),
91                ]),
92                _ => Expression::mul(vec![
93                    Expression::integer(c),
94                    Expression::pow(
95                        Expression::symbol(var.clone()),
96                        Expression::integer(i as i64),
97                    ),
98                ]),
99            };
100
101            terms.push(term);
102        }
103
104        if terms.is_empty() {
105            Expression::integer(0)
106        } else if terms.len() == 1 {
107            terms.pop().unwrap()
108        } else {
109            Expression::add(terms)
110        }
111    }
112
113    /// Check if an Expression can be converted to IntPoly
114    ///
115    /// This is a fast check that doesn't allocate.
116    #[inline]
117    pub fn can_convert(expr: &Expression, var: &Symbol) -> bool {
118        is_int_polynomial(expr, var)
119    }
120}
121
122/// Extract integer coefficients from Expression
123///
124/// Returns false if any non-integer coefficient is found
125fn extract_int_coefficients(
126    expr: &Expression,
127    var: &Symbol,
128    coeffs: &mut std::collections::HashMap<i64, i64>,
129) -> bool {
130    match expr {
131        Expression::Number(Number::Integer(n)) => {
132            *coeffs.entry(0).or_insert(0) += n;
133            true
134        }
135        Expression::Number(Number::BigInteger(_)) => false,
136        Expression::Symbol(s) if s == var => {
137            *coeffs.entry(1).or_insert(0) += 1;
138            true
139        }
140        Expression::Symbol(_) => false,
141        Expression::Pow(base, exp) => {
142            if let (Expression::Symbol(s), Expression::Number(Number::Integer(n))) =
143                (base.as_ref(), exp.as_ref())
144            {
145                if s == var && *n >= 0 {
146                    *coeffs.entry(*n).or_insert(0) += 1;
147                    return true;
148                }
149            }
150            false
151        }
152        Expression::Mul(factors) => {
153            let mut coeff = 1i64;
154            let mut degree = 0i64;
155
156            for factor in factors.iter() {
157                match factor {
158                    Expression::Number(Number::Integer(n)) => match coeff.checked_mul(*n) {
159                        Some(c) => coeff = c,
160                        None => return false,
161                    },
162                    Expression::Symbol(s) if s == var => {
163                        degree += 1;
164                    }
165                    Expression::Pow(base, exp) => {
166                        if let (Expression::Symbol(s), Expression::Number(Number::Integer(n))) =
167                            (base.as_ref(), exp.as_ref())
168                        {
169                            if s == var && *n >= 0 {
170                                degree += *n;
171                            } else {
172                                return false;
173                            }
174                        } else {
175                            return false;
176                        }
177                    }
178                    _ => return false,
179                }
180            }
181
182            *coeffs.entry(degree).or_insert(0) += coeff;
183            true
184        }
185        Expression::Add(terms) => {
186            for term in terms.iter() {
187                if !extract_int_coefficients(term, var, coeffs) {
188                    return false;
189                }
190            }
191            true
192        }
193        _ => false,
194    }
195}
196
197/// Check if Expression is a polynomial with integer coefficients
198fn is_int_polynomial(expr: &Expression, var: &Symbol) -> bool {
199    match expr {
200        Expression::Number(Number::Integer(_)) => true,
201        Expression::Number(_) => false,
202        Expression::Symbol(s) => s == var,
203        Expression::Pow(base, exp) => {
204            if let (Expression::Symbol(s), Expression::Number(Number::Integer(n))) =
205                (base.as_ref(), exp.as_ref())
206            {
207                s == var && *n >= 0
208            } else {
209                false
210            }
211        }
212        Expression::Mul(factors) => {
213            let mut has_valid_var_term = true;
214            for factor in factors.iter() {
215                match factor {
216                    Expression::Number(Number::Integer(_)) => {}
217                    Expression::Symbol(s) if s == var => {}
218                    Expression::Pow(base, exp) => {
219                        if let (Expression::Symbol(s), Expression::Number(Number::Integer(n))) =
220                            (base.as_ref(), exp.as_ref())
221                        {
222                            if s != var || *n < 0 {
223                                has_valid_var_term = false;
224                            }
225                        } else {
226                            has_valid_var_term = false;
227                        }
228                    }
229                    _ => {
230                        has_valid_var_term = false;
231                    }
232                }
233            }
234            has_valid_var_term
235        }
236        Expression::Add(terms) => terms.iter().all(|t| is_int_polynomial(t, var)),
237        _ => false,
238    }
239}