mathhook_core/core/polynomial/
dispatch.rs

1//! Unified polynomial dispatch layer
2//!
3//! Auto-routes Expression to optimal `Poly<T>` implementation based on coefficient type.
4//! Converts ONCE at entry → stays numeric → converts ONCE at exit.
5//!
6//! # Architecture
7//!
8//! 1. Analyze coefficient types in Expression
9//! 2. Route to optimal implementation:
10//!    - All integers → IntPoly (fastest path)
11//!    - Any rationals → RationalPoly (field operations)
12//!    - Multivariate → symbolic fallback
13//! 3. Convert result back to Expression ONCE
14//!
15//! # Example
16//!
17//! ```rust
18//! use mathhook_core::{expr, symbol};
19//! use mathhook_core::core::polynomial::dispatch::polynomial_gcd;
20//!
21//! let x = symbol!(x);
22//! let p1 = expr!((x^2) - 1);
23//! let p2 = expr!(x - 1);
24//! let gcd = polynomial_gcd(&p1, &p2, &x);
25//! ```
26
27use crate::core::polynomial::poly::{IntPoly, RationalPoly};
28use crate::core::{Expression, Number, Symbol};
29use num_bigint::BigInt;
30use num_rational::BigRational;
31use num_rational::Ratio;
32use num_traits::ToPrimitive;
33
/// Coefficient domain required to represent a polynomial exactly.
///
/// Variants are ordered from most to least specialised; `analyze_coefficient_type`
/// reports the most general one demanded by any part of an expression.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum CoefficientType {
    /// Every coefficient is a machine integer (routed to `IntPoly`).
    Integer,
    /// At least one coefficient is a rational or float constant (routed to `RationalPoly`).
    Rational,
    /// Not representable by either numeric polynomial type; handled by the symbolic fallback.
    Symbolic,
}
40
41/// Analyze coefficient types in an Expression
42///
43/// Returns the most general coefficient type needed:
44/// - Integer: all coefficients are integers
45/// - Rational: at least one rational coefficient
46/// - Symbolic: contains symbolic expressions (not polynomial)
47fn analyze_coefficient_type(expr: &Expression, var: &Symbol) -> CoefficientType {
48    match expr {
49        Expression::Number(Number::Integer(_)) => CoefficientType::Integer,
50        Expression::Number(Number::Rational(_)) => CoefficientType::Rational,
51        Expression::Number(Number::Float(_)) => CoefficientType::Rational,
52        Expression::Number(Number::BigInteger(_)) => CoefficientType::Symbolic,
53        Expression::Symbol(s) if s == var => CoefficientType::Integer,
54        Expression::Symbol(_) => CoefficientType::Symbolic,
55        Expression::Pow(base, exp) => {
56            if let (Expression::Symbol(s), Expression::Number(Number::Integer(n))) =
57                (base.as_ref(), exp.as_ref())
58            {
59                if s == var && *n >= 0 {
60                    CoefficientType::Integer
61                } else {
62                    CoefficientType::Symbolic
63                }
64            } else {
65                CoefficientType::Symbolic
66            }
67        }
68        Expression::Mul(factors) => {
69            let mut has_rational = false;
70            for factor in factors.iter() {
71                match analyze_coefficient_type(factor, var) {
72                    CoefficientType::Symbolic => return CoefficientType::Symbolic,
73                    CoefficientType::Rational => has_rational = true,
74                    CoefficientType::Integer => {}
75                }
76            }
77            if has_rational {
78                CoefficientType::Rational
79            } else {
80                CoefficientType::Integer
81            }
82        }
83        Expression::Add(terms) => {
84            let mut has_rational = false;
85            for term in terms.iter() {
86                match analyze_coefficient_type(term, var) {
87                    CoefficientType::Symbolic => return CoefficientType::Symbolic,
88                    CoefficientType::Rational => has_rational = true,
89                    CoefficientType::Integer => {}
90                }
91            }
92            if has_rational {
93                CoefficientType::Rational
94            } else {
95                CoefficientType::Integer
96            }
97        }
98        _ => CoefficientType::Symbolic,
99    }
100}
101
102/// Unified polynomial GCD with automatic type routing
103///
104/// Analyzes coefficient types and routes to optimal implementation:
105/// - All integers → IntPoly GCD (fastest)
106/// - Any rationals → RationalPoly GCD (field operations)
107/// - Symbolic → fallback to Euclidean algorithm
108///
109/// # Arguments
110/// * `a` - First polynomial
111/// * `b` - Second polynomial
112/// * `var` - Variable to treat as polynomial variable
113///
114/// # Example
115/// ```rust
116/// use mathhook_core::{expr, symbol};
117/// use mathhook_core::core::polynomial::dispatch::polynomial_gcd;
118///
119/// let x = symbol!(x);
120/// let p1 = expr!((x^2) - 1);
121/// let p2 = expr!(x - 1);
122/// let gcd = polynomial_gcd(&p1, &p2, &x);
123/// ```
124pub fn polynomial_gcd(a: &Expression, b: &Expression, var: &Symbol) -> Expression {
125    let a_type = analyze_coefficient_type(a, var);
126    let b_type = analyze_coefficient_type(b, var);
127
128    match (a_type, b_type) {
129        (CoefficientType::Integer, CoefficientType::Integer) => {
130            if let (Some(poly_a), Some(poly_b)) = (
131                IntPoly::try_from_expression(a, var),
132                IntPoly::try_from_expression(b, var),
133            ) {
134                if let Ok(gcd_poly) = poly_a.gcd_i64(&poly_b) {
135                    return gcd_poly.to_expression(var);
136                }
137            }
138        }
139        (CoefficientType::Rational, _) | (_, CoefficientType::Rational) => {
140            if let (Some(poly_a), Some(poly_b)) = (
141                try_rational_poly_from_expression(a, var),
142                try_rational_poly_from_expression(b, var),
143            ) {
144                if let Ok(gcd_poly) = poly_a.gcd(&poly_b) {
145                    return rational_poly_to_expression(&gcd_poly, var);
146                }
147            }
148        }
149        _ => {}
150    }
151
152    symbolic_gcd(a, b, var)
153}
154
155/// Unified polynomial division with automatic type routing
156///
157/// Returns (quotient, remainder) such that:
158/// `dividend = divisor * quotient + remainder`
159///
160/// # Arguments
161/// * `dividend` - Polynomial to divide
162/// * `divisor` - Polynomial to divide by
163/// * `var` - Variable to treat as polynomial variable
164///
165/// # Example
166/// ```rust
167/// use mathhook_core::{expr, symbol};
168/// use mathhook_core::core::polynomial::dispatch::polynomial_div;
169///
170/// let x = symbol!(x);
171/// let dividend = expr!((x^2) + (3*x) + 2);
172/// let divisor = expr!(x + 1);
173/// let (quot, rem) = polynomial_div(&dividend, &divisor, &x);
174/// ```
175pub fn polynomial_div(
176    dividend: &Expression,
177    divisor: &Expression,
178    var: &Symbol,
179) -> (Expression, Expression) {
180    if divisor.is_zero() {
181        return (Expression::undefined(), Expression::undefined());
182    }
183
184    if dividend.is_zero() {
185        return (Expression::integer(0), Expression::integer(0));
186    }
187
188    if dividend == divisor {
189        return (Expression::integer(1), Expression::integer(0));
190    }
191
192    let dividend_type = analyze_coefficient_type(dividend, var);
193    let divisor_type = analyze_coefficient_type(divisor, var);
194
195    match (dividend_type, divisor_type) {
196        (CoefficientType::Integer, CoefficientType::Integer) => {
197            if let (Some(poly_dividend), Some(poly_divisor)) = (
198                IntPoly::try_from_expression(dividend, var),
199                IntPoly::try_from_expression(divisor, var),
200            ) {
201                if let Ok((q, r)) = poly_dividend.div_rem(&poly_divisor) {
202                    return (q.to_expression(var), r.to_expression(var));
203                }
204            }
205        }
206        (CoefficientType::Rational, _) | (_, CoefficientType::Rational) => {
207            if let (Some(poly_dividend), Some(poly_divisor)) = (
208                try_rational_poly_from_expression(dividend, var),
209                try_rational_poly_from_expression(divisor, var),
210            ) {
211                if let Ok((q, r)) = poly_dividend.div_rem(&poly_divisor) {
212                    return (
213                        rational_poly_to_expression(&q, var),
214                        rational_poly_to_expression(&r, var),
215                    );
216                }
217            }
218        }
219        _ => {}
220    }
221
222    symbolic_div(dividend, divisor, var)
223}
224
225/// Unified polynomial remainder with automatic type routing
226///
227/// # Arguments
228/// * `dividend` - Polynomial to divide
229/// * `divisor` - Polynomial to divide by
230/// * `var` - Variable to treat as polynomial variable
231///
232/// # Example
233/// ```rust
234/// use mathhook_core::{expr, symbol};
235/// use mathhook_core::core::polynomial::dispatch::polynomial_rem;
236///
237/// let x = symbol!(x);
238/// let dividend = expr!((x^2) + 1);
239/// let divisor = expr!(x - 1);
240/// let rem = polynomial_rem(&dividend, &divisor, &x);
241/// ```
242pub fn polynomial_rem(dividend: &Expression, divisor: &Expression, var: &Symbol) -> Expression {
243    polynomial_div(dividend, divisor, var).1
244}
245
246/// Unified polynomial quotient with automatic type routing
247///
248/// # Arguments
249/// * `dividend` - Polynomial to divide
250/// * `divisor` - Polynomial to divide by
251/// * `var` - Variable to treat as polynomial variable
252///
253/// # Example
254/// ```rust
255/// use mathhook_core::{expr, symbol};
256/// use mathhook_core::core::polynomial::dispatch::polynomial_quo;
257///
258/// let x = symbol!(x);
259/// let dividend = expr!((x^2) - 1);
260/// let divisor = expr!(x - 1);
261/// let quot = polynomial_quo(&dividend, &divisor, &x);
262/// ```
263pub fn polynomial_quo(dividend: &Expression, divisor: &Expression, var: &Symbol) -> Expression {
264    polynomial_div(dividend, divisor, var).0
265}
266
267/// Try to convert Expression to RationalPoly
268fn try_rational_poly_from_expression(expr: &Expression, var: &Symbol) -> Option<RationalPoly> {
269    let mut coeffs = std::collections::HashMap::new();
270
271    if !extract_rational_coefficients(expr, var, &mut coeffs) {
272        return None;
273    }
274
275    if coeffs.is_empty() {
276        return Some(RationalPoly::zero());
277    }
278
279    let max_deg = *coeffs.keys().max()?;
280    if max_deg > 1000 {
281        return None;
282    }
283
284    let mut coeff_vec = vec![Ratio::new(0, 1); max_deg as usize + 1];
285    for (deg, coeff) in coeffs {
286        if deg >= 0 {
287            coeff_vec[deg as usize] = coeff;
288        }
289    }
290
291    Some(RationalPoly::from_coeffs(coeff_vec))
292}
293
294/// Try to convert BigRational to Ratio<i64>
295fn try_bigrational_to_ratio(r: &num_rational::BigRational) -> Option<Ratio<i64>> {
296    let numer = r.numer().to_i64()?;
297    let denom = r.denom().to_i64()?;
298    Some(Ratio::new(numer, denom))
299}
300
301/// Extract rational coefficients from Expression
302fn extract_rational_coefficients(
303    expr: &Expression,
304    var: &Symbol,
305    coeffs: &mut std::collections::HashMap<i64, Ratio<i64>>,
306) -> bool {
307    match expr {
308        Expression::Number(Number::Integer(n)) => {
309            let entry = coeffs.entry(0).or_insert_with(|| Ratio::new(0, 1));
310            *entry += Ratio::new(*n, 1);
311            true
312        }
313        Expression::Number(Number::Rational(r)) => {
314            if let Some(ratio) = try_bigrational_to_ratio(r) {
315                let entry = coeffs.entry(0).or_insert_with(|| Ratio::new(0, 1));
316                *entry += ratio;
317                true
318            } else {
319                false
320            }
321        }
322        Expression::Number(Number::Float(f)) => {
323            let approx = (*f * 1000000.0).round() as i64;
324            let entry = coeffs.entry(0).or_insert_with(|| Ratio::new(0, 1));
325            *entry += Ratio::new(approx, 1000000);
326            true
327        }
328        Expression::Number(_) => false,
329        Expression::Symbol(s) if s == var => {
330            let entry = coeffs.entry(1).or_insert_with(|| Ratio::new(0, 1));
331            *entry += Ratio::new(1, 1);
332            true
333        }
334        Expression::Symbol(_) => false,
335        Expression::Pow(base, exp) => {
336            if let (Expression::Symbol(s), Expression::Number(Number::Integer(n))) =
337                (base.as_ref(), exp.as_ref())
338            {
339                if s == var && *n >= 0 {
340                    let entry = coeffs.entry(*n).or_insert_with(|| Ratio::new(0, 1));
341                    *entry += Ratio::new(1, 1);
342                    return true;
343                }
344            }
345            false
346        }
347        Expression::Mul(factors) => {
348            let mut coeff = Ratio::new(1, 1);
349            let mut degree = 0i64;
350
351            for factor in factors.iter() {
352                match factor {
353                    Expression::Number(Number::Integer(n)) => {
354                        coeff *= Ratio::new(*n, 1);
355                    }
356                    Expression::Number(Number::Rational(r)) => {
357                        if let Some(ratio) = try_bigrational_to_ratio(r) {
358                            coeff *= ratio;
359                        } else {
360                            return false;
361                        }
362                    }
363                    Expression::Number(Number::Float(f)) => {
364                        let approx = (*f * 1000000.0).round() as i64;
365                        coeff *= Ratio::new(approx, 1000000);
366                    }
367                    Expression::Symbol(s) if s == var => {
368                        degree += 1;
369                    }
370                    Expression::Pow(base, exp) => {
371                        if let (Expression::Symbol(s), Expression::Number(Number::Integer(n))) =
372                            (base.as_ref(), exp.as_ref())
373                        {
374                            if s == var && *n >= 0 {
375                                degree += *n;
376                            } else {
377                                return false;
378                            }
379                        } else {
380                            return false;
381                        }
382                    }
383                    _ => return false,
384                }
385            }
386
387            let entry = coeffs.entry(degree).or_insert_with(|| Ratio::new(0, 1));
388            *entry += coeff;
389            true
390        }
391        Expression::Add(terms) => {
392            for term in terms.iter() {
393                if !extract_rational_coefficients(term, var, coeffs) {
394                    return false;
395                }
396            }
397            true
398        }
399        _ => false,
400    }
401}
402
403/// Convert RationalPoly to Expression
404fn rational_poly_to_expression(poly: &RationalPoly, var: &Symbol) -> Expression {
405    if poly.is_zero() {
406        return Expression::integer(0);
407    }
408
409    let mut terms = Vec::new();
410
411    for (i, c) in poly.coefficients().iter().enumerate() {
412        if c.numer() == &0 {
413            continue;
414        }
415
416        let coeff_expr = if c.denom() == &1 {
417            Expression::integer(*c.numer())
418        } else {
419            Expression::Number(Number::rational(BigRational::new(
420                BigInt::from(*c.numer()),
421                BigInt::from(*c.denom()),
422            )))
423        };
424
425        let term = match i {
426            0 => coeff_expr,
427            1 if c.numer() == &1 && c.denom() == &1 => Expression::symbol(var.clone()),
428            1 => Expression::mul(vec![coeff_expr, Expression::symbol(var.clone())]),
429            _ if c.numer() == &1 && c.denom() == &1 => Expression::pow(
430                Expression::symbol(var.clone()),
431                Expression::integer(i as i64),
432            ),
433            _ => Expression::mul(vec![
434                coeff_expr,
435                Expression::pow(
436                    Expression::symbol(var.clone()),
437                    Expression::integer(i as i64),
438                ),
439            ]),
440        };
441
442        terms.push(term);
443    }
444
445    if terms.is_empty() {
446        Expression::integer(0)
447    } else if terms.len() == 1 {
448        terms.pop().unwrap()
449    } else {
450        Expression::add(terms)
451    }
452}
453
454/// Symbolic GCD fallback using Euclidean algorithm
455fn symbolic_gcd(p1: &Expression, p2: &Expression, var: &Symbol) -> Expression {
456    let mut a = p1.clone();
457    let mut b = p2.clone();
458
459    for _ in 0..10 {
460        if b.is_zero() {
461            return a;
462        }
463
464        let remainder = symbolic_div(&a, &b, var).1;
465        a = b;
466        b = remainder;
467    }
468
469    Expression::integer(1)
470}
471
472/// Symbolic division fallback
473fn symbolic_div(
474    dividend: &Expression,
475    _divisor: &Expression,
476    _var: &Symbol,
477) -> (Expression, Expression) {
478    (Expression::integer(0), dividend.clone())
479}
480
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{expr, symbol};

    /// The rational constant `1/2` as an `Expression`.
    fn half() -> Expression {
        Expression::Number(Number::rational(BigRational::new(
            BigInt::from(1),
            BigInt::from(2),
        )))
    }

    #[test]
    fn test_analyze_integer_coefficients() {
        let x = symbol!(x);
        let poly = expr!((x ^ 2) + (2 * x) + 1);
        assert_eq!(
            analyze_coefficient_type(&poly, &x),
            CoefficientType::Integer
        );
    }

    #[test]
    fn test_analyze_rational_coefficients() {
        let x = symbol!(x);
        let poly = Expression::add(vec![
            Expression::pow(Expression::symbol(x.clone()), Expression::integer(2)),
            Expression::mul(vec![half(), Expression::symbol(x.clone())]),
            Expression::integer(1),
        ]);
        assert_eq!(
            analyze_coefficient_type(&poly, &x),
            CoefficientType::Rational
        );
    }

    #[test]
    fn test_analyze_foreign_symbol_is_symbolic() {
        let x = symbol!(x);
        let y = symbol!(y);
        assert_eq!(
            analyze_coefficient_type(&Expression::symbol(y), &x),
            CoefficientType::Symbolic
        );
    }

    #[test]
    fn test_analyze_negative_power_is_symbolic() {
        let x = symbol!(x);
        let inverse = Expression::pow(Expression::symbol(x.clone()), Expression::integer(-1));
        assert_eq!(
            analyze_coefficient_type(&inverse, &x),
            CoefficientType::Symbolic
        );
    }

    #[test]
    fn test_polynomial_gcd_integers() {
        let x = symbol!(x);
        let p1 = expr!((x ^ 2) - 1);
        let p2 = expr!(x - 1);
        let gcd = polynomial_gcd(&p1, &p2, &x);
        assert!(!gcd.is_zero());
    }

    #[test]
    fn test_polynomial_div_integers() {
        let x = symbol!(x);
        let dividend = expr!((x ^ 2) + (3 * x) + 2);
        let divisor = expr!(x + 1);
        let (quot, rem) = polynomial_div(&dividend, &divisor, &x);
        assert!(!quot.is_zero());
        // x^2 + 3x + 2 = (x + 1)(x + 2), so the division is exact.
        assert!(rem.is_zero());
    }

    #[test]
    fn test_polynomial_div_zero_dividend() {
        let x = symbol!(x);
        let divisor = expr!(x + 1);
        let (quot, rem) = polynomial_div(&Expression::integer(0), &divisor, &x);
        assert!(quot.is_zero());
        assert!(rem.is_zero());
    }

    #[test]
    fn test_polynomial_div_identical_operands() {
        let x = symbol!(x);
        let p = expr!(x + 1);
        let (quot, rem) = polynomial_div(&p, &p, &x);
        assert!(quot == Expression::integer(1));
        assert!(rem.is_zero());
    }

    #[test]
    fn test_polynomial_rem_integers() {
        let x = symbol!(x);
        let dividend = expr!((x ^ 2) + 1);
        let divisor = expr!(x - 1);
        let rem = polynomial_rem(&dividend, &divisor, &x);
        // x^2 + 1 = (x - 1)(x + 1) + 2, so the remainder is non-zero.
        assert!(!rem.is_zero());
    }

    #[test]
    fn test_polynomial_quo_integers() {
        let x = symbol!(x);
        let dividend = expr!((x ^ 2) - 1);
        let divisor = expr!(x - 1);
        let quot = polynomial_quo(&dividend, &divisor, &x);
        assert!(!quot.is_zero());
        // x^2 - 1 = (x - 1)(x + 1), so the division is exact.
        assert!(polynomial_rem(&dividend, &divisor, &x).is_zero());
    }

    #[test]
    fn test_rational_poly_conversion() {
        let x = symbol!(x);
        let poly_expr = Expression::add(vec![
            Expression::mul(vec![half(), Expression::symbol(x.clone())]),
            Expression::integer(1),
        ]);

        let poly = try_rational_poly_from_expression(&poly_expr, &x);
        assert!(poly.is_some());

        let poly = poly.unwrap();
        assert_eq!(poly.degree(), Some(1));
        assert_eq!(poly.coeff(0), Ratio::new(1, 1));
        assert_eq!(poly.coeff(1), Ratio::new(1, 2));
    }

    #[test]
    fn test_rational_poly_gcd() {
        let x = symbol!(x);
        let p1 = Expression::mul(vec![half(), expr!((x ^ 2) - 1)]);
        let p2 = expr!(x - 1);

        let gcd = polynomial_gcd(&p1, &p2, &x);
        assert!(!gcd.is_zero());
    }
}