mathhook_core/core/polynomial/
classification.rs

1//! Polynomial Classification
2//!
3//! Automatic detection of polynomial structure for intelligent routing.
4//! The `PolynomialClassification` trait provides methods for checking
5//! polynomial structure and classifying expressions.
6
7use crate::core::expression::ExpressionClass;
8use crate::core::polynomial::poly::IntPoly;
9use crate::core::{Expression, Number, Symbol};
10
/// Trait for polynomial classification
///
/// Provides automatic detection of polynomial structure.
/// Implemented for `Expression` to enable classification-based
/// algorithm routing.
///
/// All checks are structural: they inspect the expression tree as it is
/// written and do not expand or simplify it first.
///
/// # Examples
///
/// ```rust
/// use mathhook_core::core::polynomial::PolynomialClassification;
/// use mathhook_core::core::ExpressionClass;
/// use mathhook_core::{expr, symbol};
///
/// let x = symbol!(x);
/// let poly = expr!(x ^ 2 + x + 1);
///
/// assert!(poly.is_polynomial());
/// assert_eq!(poly.polynomial_variables().len(), 1);
///
/// match poly.classify() {
///     ExpressionClass::UnivariatePolynomial { degree, .. } => {
///         assert_eq!(degree, 2);
///     }
///     _ => panic!("Expected univariate polynomial"),
/// }
/// ```
pub trait PolynomialClassification {
    /// Check if expression is a valid polynomial
    ///
    /// A polynomial is an expression composed only of:
    /// - Numeric constants
    /// - Symbols (variables)
    /// - Addition and subtraction
    /// - Multiplication
    /// - Powers whose exponent is a non-negative integer literal
    ///
    /// Function calls and every other kind of node make the expression
    /// non-polynomial.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use mathhook_core::core::polynomial::PolynomialClassification;
    /// use mathhook_core::{expr, symbol};
    ///
    /// let x = symbol!(x);
    ///
    /// // Polynomials
    /// assert!(expr!(x ^ 2).is_polynomial());
    /// assert!(expr!(x + 1).is_polynomial());
    /// assert!(expr!(3 * x).is_polynomial());
    ///
    /// // Not polynomials
    /// assert!(!expr!(sin(x)).is_polynomial());
    /// ```
    fn is_polynomial(&self) -> bool;

    /// Check if polynomial in specific variables
    ///
    /// Returns true if the expression is a polynomial when treating
    /// only the given variables as indeterminates (others are constants).
    ///
    /// # Arguments
    ///
    /// * `vars` - The variables to treat as indeterminates
    ///
    /// # Examples
    ///
    /// ```rust
    /// use mathhook_core::core::polynomial::PolynomialClassification;
    /// use mathhook_core::{expr, symbol};
    ///
    /// let x = symbol!(x);
    /// let y = symbol!(y);
    /// let poly = expr!(x * y);
    ///
    /// assert!(poly.is_polynomial_in(&[x.clone()]));
    /// assert!(poly.is_polynomial_in(&[y.clone()]));
    /// assert!(poly.is_polynomial_in(&[x.clone(), y.clone()]));
    /// ```
    fn is_polynomial_in(&self, vars: &[Symbol]) -> bool;

    /// Get the symbols that occur in the expression
    ///
    /// Each distinct symbol is returned once. For the `Expression`
    /// implementation in this module, symbols are gathered from sums,
    /// products, and both the base and the exponent of powers; other node
    /// kinds (for example function calls) are not searched. No polynomial
    /// check is performed first, so a non-polynomial expression does not
    /// automatically yield an empty list.
    ///
    /// Callers should not rely on the order of the returned symbols.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use mathhook_core::core::polynomial::PolynomialClassification;
    /// use mathhook_core::{expr, symbol};
    ///
    /// let x = symbol!(x);
    /// let y = symbol!(y);
    /// let poly = expr!(x + y);
    ///
    /// let vars = poly.polynomial_variables();
    /// assert_eq!(vars.len(), 2);
    /// ```
    fn polynomial_variables(&self) -> Vec<Symbol>;

    /// Classify expression type for routing
    ///
    /// Returns the classification that determines which algorithm to use.
    /// This enables intelligent dispatch to specialized algorithms.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use mathhook_core::core::polynomial::PolynomialClassification;
    /// use mathhook_core::core::ExpressionClass;
    /// use mathhook_core::{expr, symbol};
    ///
    /// let x = symbol!(x);
    ///
    /// // Integer classification
    /// assert_eq!(expr!(5).classify(), ExpressionClass::Integer);
    ///
    /// // Univariate polynomial
    /// match expr!(x ^ 2).classify() {
    ///     ExpressionClass::UnivariatePolynomial { degree, .. } => {
    ///         assert_eq!(degree, 2);
    ///     }
    ///     _ => panic!("Expected univariate polynomial"),
    /// }
    /// ```
    fn classify(&self) -> ExpressionClass;

    /// Check if expression can be represented as IntPoly
    ///
    /// Returns true if the expression mentions exactly one symbol and is
    /// built only from integer literals, symbols, sums, products, and
    /// powers with a non-negative integer literal exponent. This is a fast
    /// heuristic check; it does not attempt the conversion itself.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use mathhook_core::core::polynomial::PolynomialClassification;
    /// use mathhook_core::{expr, symbol};
    ///
    /// let x = symbol!(x);
    ///
    /// assert!(expr!(x ^ 2 + 2 * x + 1).is_intpoly_compatible());
    /// assert!(!expr!(1.5 * x + 2).is_intpoly_compatible());
    /// ```
    fn is_intpoly_compatible(&self) -> bool;

    /// Try to convert to IntPoly
    ///
    /// Returns the `IntPoly` together with its variable if the expression
    /// is a univariate integer polynomial, `None` otherwise. An expression
    /// that mentions zero symbols, or more than one, always yields `None`.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use mathhook_core::core::polynomial::PolynomialClassification;
    /// use mathhook_core::{expr, symbol};
    ///
    /// let x = symbol!(x);
    /// let poly_expr = expr!(x ^ 2 + 2 * x + 3);
    ///
    /// if let Some((intpoly, var)) = poly_expr.try_as_intpoly() {
    ///     assert_eq!(var, x);
    ///     assert_eq!(intpoly.degree(), Some(2));
    /// }
    /// ```
    fn try_as_intpoly(&self) -> Option<(IntPoly, Symbol)>;
}
175
176impl PolynomialClassification for Expression {
177    fn is_polynomial(&self) -> bool {
178        is_polynomial_impl(self)
179    }
180
181    fn is_polynomial_in(&self, vars: &[Symbol]) -> bool {
182        is_polynomial_in_impl(self, vars)
183    }
184
185    fn polynomial_variables(&self) -> Vec<Symbol> {
186        collect_polynomial_variables(self)
187    }
188
189    fn classify(&self) -> ExpressionClass {
190        classify_impl(self)
191    }
192
193    fn is_intpoly_compatible(&self) -> bool {
194        let vars = self.polynomial_variables();
195        if vars.len() != 1 {
196            return false;
197        }
198        has_only_integer_coefficients(self)
199    }
200
201    fn try_as_intpoly(&self) -> Option<(IntPoly, Symbol)> {
202        let vars = self.polynomial_variables();
203        if vars.len() != 1 {
204            return None;
205        }
206        let var = &vars[0];
207        IntPoly::try_from_expression(self, var).map(|poly| (poly, var.clone()))
208    }
209}
210
211/// Extract integer value from expression if it's an integer
212fn extract_integer(expr: &Expression) -> Option<i64> {
213    match expr {
214        Expression::Number(Number::Integer(n)) => Some(*n),
215        _ => None,
216    }
217}
218
219/// Check if expression is a rational number
220fn is_rational(expr: &Expression) -> bool {
221    matches!(expr, Expression::Number(Number::Rational(_)))
222}
223
224/// Check if expression is a polynomial (no transcendental functions, positive powers only)
225fn is_polynomial_impl(expr: &Expression) -> bool {
226    match expr {
227        Expression::Number(_) => true,
228        Expression::Symbol(_) => true,
229        Expression::Add(terms) | Expression::Mul(terms) => terms.iter().all(is_polynomial_impl),
230        Expression::Pow(base, exp) => {
231            if !is_polynomial_impl(base) {
232                return false;
233            }
234            if let Some(n) = extract_integer(exp) {
235                n >= 0
236            } else {
237                false
238            }
239        }
240        Expression::Function { .. } => false,
241        _ => false,
242    }
243}
244
245/// Check if expression is a polynomial in specific variables
246fn is_polynomial_in_impl(expr: &Expression, vars: &[Symbol]) -> bool {
247    match expr {
248        Expression::Number(_) => true,
249        Expression::Symbol(_s) => true,
250        Expression::Add(terms) | Expression::Mul(terms) => {
251            terms.iter().all(|t| is_polynomial_in_impl(t, vars))
252        }
253        Expression::Pow(base, exp) => {
254            if !is_polynomial_in_impl(base, vars) {
255                return false;
256            }
257            if let Some(n) = extract_integer(exp) {
258                n >= 0
259            } else {
260                let exp_vars = collect_polynomial_variables(exp);
261                !exp_vars.iter().any(|v| vars.contains(v))
262            }
263        }
264        Expression::Function { .. } => false,
265        _ => false,
266    }
267}
268
269/// Collect all variables from a polynomial expression
270fn collect_polynomial_variables(expr: &Expression) -> Vec<Symbol> {
271    use std::collections::HashSet;
272    let mut vars = HashSet::new();
273    collect_vars_impl(expr, &mut vars);
274    vars.into_iter().collect()
275}
276
277fn collect_vars_impl(expr: &Expression, vars: &mut std::collections::HashSet<Symbol>) {
278    match expr {
279        Expression::Symbol(s) => {
280            vars.insert(s.clone());
281        }
282        Expression::Add(terms) | Expression::Mul(terms) => {
283            for term in terms.iter() {
284                collect_vars_impl(term, vars);
285            }
286        }
287        Expression::Pow(base, exp) => {
288            collect_vars_impl(base, vars);
289            collect_vars_impl(exp, vars);
290        }
291        _ => {}
292    }
293}
294
295/// Classify expression for algorithm routing
296fn classify_impl(expr: &Expression) -> ExpressionClass {
297    if extract_integer(expr).is_some() {
298        return ExpressionClass::Integer;
299    }
300
301    if !is_polynomial_impl(expr) {
302        if contains_transcendental(expr) {
303            return ExpressionClass::Transcendental;
304        }
305        return ExpressionClass::Symbolic;
306    }
307
308    let vars = collect_polynomial_variables(expr);
309
310    match vars.len() {
311        0 => {
312            if is_rational(expr) {
313                ExpressionClass::Rational
314            } else {
315                ExpressionClass::Integer
316            }
317        }
318        1 => {
319            let var = vars.into_iter().next().unwrap();
320            let degree = compute_degree(expr, &var).unwrap_or(0);
321            ExpressionClass::UnivariatePolynomial { var, degree }
322        }
323        _ => {
324            let total_degree = vars.iter().filter_map(|v| compute_degree(expr, v)).sum();
325            ExpressionClass::MultivariatePolynomial { vars, total_degree }
326        }
327    }
328}
329
330/// Check if expression contains transcendental functions
331fn contains_transcendental(expr: &Expression) -> bool {
332    match expr {
333        Expression::Function { name, .. } => {
334            let transcendental_fns = [
335                "sin", "cos", "tan", "cot", "sec", "csc", "sinh", "cosh", "tanh", "exp", "log",
336                "ln", "arcsin", "arccos", "arctan",
337            ];
338            transcendental_fns.contains(&name.as_ref())
339        }
340        Expression::Add(terms) | Expression::Mul(terms) => {
341            terms.iter().any(contains_transcendental)
342        }
343        Expression::Pow(base, exp) => contains_transcendental(base) || contains_transcendental(exp),
344        _ => false,
345    }
346}
347
348/// Compute degree of polynomial with respect to a variable
349fn compute_degree(expr: &Expression, var: &Symbol) -> Option<i64> {
350    match expr {
351        Expression::Number(_) => Some(0),
352        Expression::Symbol(s) => {
353            if s == var {
354                Some(1)
355            } else {
356                Some(0)
357            }
358        }
359        Expression::Add(terms) => terms.iter().filter_map(|t| compute_degree(t, var)).max(),
360        Expression::Mul(terms) => {
361            let degrees: Option<Vec<i64>> = terms.iter().map(|t| compute_degree(t, var)).collect();
362            degrees.map(|ds| ds.into_iter().sum())
363        }
364        Expression::Pow(base, exp) => {
365            let base_deg = compute_degree(base, var)?;
366            let exp_val = extract_integer(exp)?;
367            Some(base_deg * exp_val)
368        }
369        _ => None,
370    }
371}
372
373/// Check if expression has only integer coefficients
374fn has_only_integer_coefficients(expr: &Expression) -> bool {
375    match expr {
376        Expression::Number(Number::Integer(_)) => true,
377        Expression::Symbol(_) => true,
378        Expression::Add(terms) | Expression::Mul(terms) => {
379            terms.iter().all(has_only_integer_coefficients)
380        }
381        Expression::Pow(base, exp) => {
382            has_only_integer_coefficients(base)
383                && matches!(exp.as_ref(), Expression::Number(Number::Integer(n)) if *n >= 0)
384        }
385        _ => false,
386    }
387}
388
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{expr, symbol};

    /// Literals, bare symbols, sums and powers are all polynomials.
    #[test]
    fn test_is_polynomial() {
        let x = symbol!(x);

        let constant = Expression::integer(5);
        assert!(constant.is_polynomial());

        let variable = Expression::symbol(x.clone());
        assert!(variable.is_polynomial());

        let linear = expr!(x + 1);
        assert!(linear.is_polynomial());

        let square = expr!(x ^ 2);
        assert!(square.is_polynomial());
    }

    /// An integer literal is classified as `Integer`.
    #[test]
    fn test_classify_integer() {
        let class = Expression::integer(5).classify();
        assert_eq!(class, ExpressionClass::Integer);
    }

    /// `x^2` is univariate in `x` with degree 2.
    #[test]
    fn test_classify_univariate() {
        let x = symbol!(x);
        let class = expr!(x ^ 2).classify();

        if let ExpressionClass::UnivariatePolynomial { var, degree } = &class {
            assert_eq!(*var, x);
            assert_eq!(*degree, 2);
        } else {
            panic!("Expected UnivariatePolynomial, got {:?}", class);
        }
    }

    /// Both symbols of `x + y` are reported, in any order.
    #[test]
    fn test_polynomial_variables() {
        let x = symbol!(x);
        let y = symbol!(y);

        let operands = vec![Expression::symbol(x.clone()), Expression::symbol(y.clone())];
        let sum = Expression::add(operands);

        let found = sum.polynomial_variables();
        assert_eq!(found.len(), 2);
        for expected in [&x, &y] {
            assert!(found.contains(expected));
        }
    }

    /// `x * y` is polynomial in `x`, in `y`, and in both together.
    #[test]
    fn test_is_polynomial_in() {
        let x = symbol!(x);
        let y = symbol!(y);
        let product = expr!(x * y);

        let only_x = std::slice::from_ref(&x);
        let only_y = std::slice::from_ref(&y);
        let both = [x.clone(), y.clone()];

        assert!(product.is_polynomial_in(only_x));
        assert!(product.is_polynomial_in(only_y));
        assert!(product.is_polynomial_in(&both));
    }

    /// `x + y` is multivariate and lists both symbols.
    #[test]
    fn test_classify_multivariate() {
        let x = symbol!(x);
        let y = symbol!(y);
        let operands = vec![Expression::symbol(x.clone()), Expression::symbol(y.clone())];
        let class = Expression::add(operands).classify();

        if let ExpressionClass::MultivariatePolynomial { vars, .. } = &class {
            assert_eq!(vars.len(), 2);
            assert!(vars.contains(&x));
            assert!(vars.contains(&y));
        } else {
            panic!("Expected MultivariatePolynomial, got {:?}", class);
        }
    }

    /// `sin(x)` is classified as `Transcendental`.
    #[test]
    fn test_classify_transcendental() {
        let x = symbol!(x);
        let argument = Expression::symbol(x);
        let sine = Expression::function("sin", vec![argument]);

        assert_eq!(sine.classify(), ExpressionClass::Transcendental);
    }

    /// Integer univariate polynomials are compatible; multivariate,
    /// non-integer and negative-power expressions are not.
    #[test]
    fn test_is_intpoly_compatible() {
        let linear = expr!(2 * x + 3);
        let quadratic = expr!(x ^ 2 + 2 * x + 1);
        assert!(linear.is_intpoly_compatible());
        assert!(quadratic.is_intpoly_compatible());

        let two_symbols = expr!(x + y);
        assert!(!two_symbols.is_intpoly_compatible());

        let float_coefficient = expr!(1.5 * x + 2);
        assert!(!float_coefficient.is_intpoly_compatible());

        let negative_power = expr!(x ^ (-1));
        assert!(!negative_power.is_intpoly_compatible());
    }

    /// `x^2 + 2x + 3` converts to an `IntPoly` with coefficients stored
    /// from the constant term upwards.
    #[test]
    fn test_try_as_intpoly() {
        let x = symbol!(x);
        let quadratic = expr!(x ^ 2 + 2 * x + 3);

        let converted = quadratic.try_as_intpoly();
        assert!(converted.is_some());

        let (intpoly, var) = converted.unwrap();
        assert_eq!(var, x);
        assert_eq!(intpoly.degree(), Some(2));
        assert_eq!(intpoly.coefficients(), &[3, 2, 1]);
    }
}
507}