mathhook_core/algebra/
multivariate_gcd.rs

1//! Multivariate polynomial GCD computation using evaluation-interpolation
2//!
3//! Implements the heuristic GCD algorithm (heugcd) from SymPy's euclidtools.py.
4//! This approach avoids the infinite recursion issues of content-primitive factorization
5//! by using integer evaluation and polynomial interpolation.
6//!
7//! # Algorithm Overview (from `[Liao95]`)
8//!
9//! For multivariate polynomials f, g in Z[x₁, ..., xₙ]:
10//! 1. Extract ground GCD (numeric content only - NON-RECURSIVE)
11//! 2. Evaluate both polynomials at integer point x₀ for main variable
12//! 3. Recursively compute GCD of resulting (n-1)-variate polynomials
13//! 4. Interpolate back to n-variate polynomial
14//! 5. Verify result by polynomial division
15//! 6. If verification fails, try new evaluation point (up to 6 attempts)
16//!
17//! # Mathematical Background
18//!
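//! The key insight is that the GCD computation can be reduced one variable at a time:
//! - Substituting an integer ξ for the main variable maps gcd(f, g) to a divisor of
//!   gcd(f(ξ, x₂, …, xₙ), g(ξ, x₂, …, xₙ)); when ξ is large compared with the
//!   coefficients, the two typically coincide.
//! - A candidate GCD is then reconstructed by expanding the integer coefficients of the
//!   evaluated GCD in base ξ with symmetric digits (SymPy calls this "interpolation").
//! - The candidate is kept only if it divides both inputs exactly; otherwise a new ξ is
//!   tried, so the heuristic may fail once its retry budget is exhausted.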
23//!
24//! # References
25//!
//! - `[Liao95]` Hsin-Chao Liao and Richard J. Fateman, "Evaluation of the heuristic
//!   polynomial GCD", Proc. ISSAC '95, ACM Press, 1995, pp. 240–247
27//! - SymPy polys/euclidtools.py: `dmp_zz_heu_gcd`, `dup_zz_heu_gcd`
28
29use crate::algebra::gcd::PolynomialGcd;
30use crate::algebra::polynomial_advanced::AdvancedPolynomial;
31use crate::core::{Expression, Number, Symbol};
32use crate::expr;
33use crate::simplify::Simplify;
34use num_traits::ToPrimitive;
35
/// How many evaluation points the heuristic GCD tries before it reports
/// [`HeuristicGCDFailed`].
const HEU_GCD_MAX_ATTEMPTS: usize = 6;
38
/// Error returned when the heuristic GCD exhausts its evaluation points without
/// finding a candidate that passes the exact-division check.
///
/// Failure does not mean the inputs are coprime; callers should fall back to a
/// different GCD algorithm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct HeuristicGCDFailed;

impl std::fmt::Display for HeuristicGCDFailed {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("heuristic polynomial GCD failed to find a verified result")
    }
}

impl std::error::Error for HeuristicGCDFailed {}
42
43/// Compute GCD of multivariate polynomials using evaluation-interpolation
44///
45/// This is the main entry point. Uses the heuristic GCD algorithm which:
46/// 1. Extracts numeric content (ground GCD)
47/// 2. Evaluates at integer points
48/// 3. Recursively reduces dimension
49/// 4. Interpolates and verifies
50///
51/// # Arguments
52///
53/// * `poly1` - First polynomial expression
54/// * `poly2` - Second polynomial expression
55/// * `vars` - List of variables (in order of elimination)
56///
57/// # Returns
58///
59/// Returns the GCD expression. Falls back to 1 if heuristic fails.
60///
61/// # Examples
62///
63/// ```rust,ignore
64/// use mathhook_core::{symbol, expr};
65/// use mathhook_core::algebra::multivariate_gcd::multivariate_gcd;
66///
67/// let x = symbol!(x);
68/// let y = symbol!(y);
69///
70/// // gcd(2xy, 3xy) = xy
71/// let p1 = expr!(2 * x * y);
72/// let p2 = expr!(3 * x * y);
73/// let result = multivariate_gcd(&p1, &p2, &[x.clone(), y.clone()]);
74/// ```
75pub fn multivariate_gcd(poly1: &Expression, poly2: &Expression, vars: &[Symbol]) -> Expression {
76    // Handle trivial cases first
77    if let Some(result) = trivial_gcd(poly1, poly2, vars) {
78        return result;
79    }
80
81    // Try heuristic GCD algorithm
82    match multivariate_heu_gcd(poly1, poly2, vars) {
83        Ok((gcd, _, _)) => gcd,
84        Err(HeuristicGCDFailed) => {
85            // Fallback: try univariate GCD if single variable
86            if vars.len() == 1 {
87                univariate_gcd_euclidean(poly1, poly2, &vars[0])
88            } else {
89                // Last resort: return 1 (conservative, but correct)
90                Expression::integer(1)
91            }
92        }
93    }
94}
95
96/// Check for trivial GCD cases (following SymPy's _dmp_rr_trivial_gcd pattern)
97fn trivial_gcd(poly1: &Expression, poly2: &Expression, vars: &[Symbol]) -> Option<Expression> {
98    if poly1.is_zero() {
99        return Some(poly2.clone());
100    }
101    if poly2.is_zero() {
102        return Some(poly1.clone());
103    }
104    if poly1 == poly2 {
105        return Some(poly1.clone());
106    }
107    if poly1.is_one() || poly2.is_one() {
108        return Some(Expression::integer(1));
109    }
110    if is_constant(poly1, vars) && is_constant(poly2, vars) {
111        return Some(poly1.gcd(poly2));
112    }
113
114    None
115}
116
117/// Heuristic multivariate polynomial GCD in Z[X]
118///
119/// Implements SymPy's `dmp_zz_heu_gcd` algorithm:
120/// 1. Base case: univariate → use `univariate_heu_gcd`
121/// 2. Extract ground GCD (numeric content only)
122/// 3. Compute bounds for evaluation point
123/// 4. For each attempt:
124///    a. Evaluate at integer point
125///    b. Recursively compute GCD
126///    c. Interpolate result
127///    d. Verify by division
128/// 5. Return (gcd, cofactor1, cofactor2) or error
129fn multivariate_heu_gcd(
130    poly1: &Expression,
131    poly2: &Expression,
132    vars: &[Symbol],
133) -> Result<(Expression, Expression, Expression), HeuristicGCDFailed> {
134    // Base case: univariate
135    if vars.len() <= 1 {
136        if vars.is_empty() {
137            let g = poly1.gcd(poly2);
138            let cff = if g.is_one() {
139                poly1.clone()
140            } else {
141                divide_exact(poly1, &g)
142            };
143            let cfg = if g.is_one() {
144                poly2.clone()
145            } else {
146                divide_exact(poly2, &g)
147            };
148            return Ok((g, cff, cfg));
149        }
150        return univariate_heu_gcd(poly1, poly2, &vars[0]);
151    }
152
153    // Check trivial cases
154    if let Some(g) = trivial_gcd(poly1, poly2, vars) {
155        let cff = divide_exact(poly1, &g);
156        let cfg = divide_exact(poly2, &g);
157        return Ok((g, cff, cfg));
158    }
159
160    // Disjoint variables check: gcd(x+1, y+1) = 1 (they share no common variables)
161    // This is the key fix for coprime polynomials in different variables
162    if have_disjoint_variables(poly1, poly2) {
163        return Ok((expr!(1), poly1.clone(), poly2.clone()));
164    }
165
166    // Extract ground GCD (numeric content only - NON-RECURSIVE!)
167    let (ground_gcd, f, g) = extract_ground_gcd(poly1, poly2);
168
169    // Main variable is the first one
170    let main_var = &vars[0];
171    let remaining_vars: Vec<Symbol> = vars[1..].to_vec();
172
173    // Compute bounds for evaluation point
174    let f_norm = polynomial_max_norm(&f, vars);
175    let g_norm = polynomial_max_norm(&g, vars);
176
177    let b = 2 * f_norm.min(g_norm) + 29;
178    let mut x = compute_initial_eval_point(b, &f, &g, main_var);
179
180    // Try up to HEU_GCD_MAX_ATTEMPTS evaluation points
181    for _ in 0..HEU_GCD_MAX_ATTEMPTS {
182        // Evaluate polynomials at x for main variable
183        let ff = polynomial_evaluate_at(&f, main_var, x);
184        let gg = polynomial_evaluate_at(&g, main_var, x);
185
186        // Skip if either evaluation is zero
187        if !ff.is_zero() && !gg.is_zero() {
188            // Recursively compute GCD of evaluated polynomials
189            if let Ok((h_eval, cff_eval, cfg_eval)) =
190                multivariate_heu_gcd(&ff, &gg, &remaining_vars)
191            {
192                // Interpolate GCD back to full dimension
193                let mut h = polynomial_interpolate(&h_eval, x, main_var);
194                h = ground_primitive(&h, vars);
195
196                // Try to verify h is the correct GCD
197                if let Some(cff) = try_exact_division(&f, &h, main_var) {
198                    if let Some(cfg) = try_exact_division(&g, &h, main_var) {
199                        // Success! Multiply by ground GCD
200                        let result_gcd = if ground_gcd == 1 {
201                            h
202                        } else {
203                            Expression::mul(vec![Expression::integer(ground_gcd), h]).simplify()
204                        };
205                        return Ok((result_gcd, cff, cfg));
206                    }
207                }
208
209                // Try cofactor approach: interpolate cff, compute h = f/cff
210                let cff = polynomial_interpolate(&cff_eval, x, main_var);
211                if let Some(h2) = try_exact_division(&f, &cff, main_var) {
212                    if let Some(cfg) = try_exact_division(&g, &h2, main_var) {
213                        let result_gcd = if ground_gcd == 1 {
214                            h2
215                        } else {
216                            Expression::mul(vec![Expression::integer(ground_gcd), h2]).simplify()
217                        };
218                        return Ok((result_gcd, cff, cfg));
219                    }
220                }
221
222                // Try other cofactor: interpolate cfg, compute h = g/cfg
223                let cfg = polynomial_interpolate(&cfg_eval, x, main_var);
224                if let Some(h3) = try_exact_division(&g, &cfg, main_var) {
225                    if let Some(cff2) = try_exact_division(&f, &h3, main_var) {
226                        let result_gcd = if ground_gcd == 1 {
227                            h3
228                        } else {
229                            Expression::mul(vec![Expression::integer(ground_gcd), h3]).simplify()
230                        };
231                        return Ok((result_gcd, cff2, cfg));
232                    }
233                }
234            }
235        }
236
237        // Update evaluation point using SymPy's formula
238        x = update_eval_point(x);
239    }
240
241    Err(HeuristicGCDFailed)
242}
243
244/// Heuristic univariate polynomial GCD in ``Z[x]``
245///
246/// Implements SymPy's `dup_zz_heu_gcd` algorithm for single variable case.
247fn univariate_heu_gcd(
248    poly1: &Expression,
249    poly2: &Expression,
250    var: &Symbol,
251) -> Result<(Expression, Expression, Expression), HeuristicGCDFailed> {
252    // Trivial cases
253    if poly1.is_zero() {
254        return Ok((poly2.clone(), expr!(0), expr!(1)));
255    }
256    if poly2.is_zero() {
257        return Ok((poly1.clone(), expr!(1), expr!(0)));
258    }
259    if poly1.is_one() || poly2.is_one() {
260        return Ok((expr!(1), poly1.clone(), poly2.clone()));
261    }
262
263    // Both don't depend on var → numeric GCD
264    if !depends_on_var(poly1, var) && !depends_on_var(poly2, var) {
265        let g = poly1.gcd(poly2);
266        let cff = divide_exact(poly1, &g);
267        let cfg = divide_exact(poly2, &g);
268        return Ok((g, cff, cfg));
269    }
270
271    // One depends on var, one doesn't → coprime (GCD = 1)
272    // Example: gcd(x+1, 5) = 1 when considering variable x
273    if !depends_on_var(poly1, var) || !depends_on_var(poly2, var) {
274        return Ok((expr!(1), poly1.clone(), poly2.clone()));
275    }
276
277    let deg1 = poly1.polynomial_degree(var).unwrap_or(0);
278    let deg2 = poly2.polynomial_degree(var).unwrap_or(0);
279
280    // Extract numeric content
281    let (ground_gcd, f, g) = extract_ground_gcd(poly1, poly2);
282
283    // Both constant after content extraction
284    if deg1 == 0 && deg2 == 0 {
285        let gcd = Expression::integer(ground_gcd);
286        return Ok((gcd, f, g));
287    }
288
289    // Compute bounds
290    let f_norm = polynomial_max_norm(&f, std::slice::from_ref(var));
291    let g_norm = polynomial_max_norm(&g, std::slice::from_ref(var));
292
293    let b = 2 * f_norm.min(g_norm) + 29;
294    let mut x = compute_initial_eval_point(b, &f, &g, var);
295
296    for _attempt in 0..HEU_GCD_MAX_ATTEMPTS {
297        // Evaluate at x
298        let ff = evaluate_univariate(&f, var, x);
299        let gg = evaluate_univariate(&g, var, x);
300
301        if ff != 0 && gg != 0 {
302            // Integer GCD
303            let h_int = gcd_integers(ff, gg);
304            let cff_int = ff / h_int;
305            let cfg_int = gg / h_int;
306
307            // Interpolate back to polynomial
308            let mut h = univariate_interpolate(h_int, x, var);
309            h = univariate_primitive(&h, var);
310
311            // Verify by division
312            if let Some(cff) = try_exact_division(&f, &h, var) {
313                if let Some(cfg) = try_exact_division(&g, &h, var) {
314                    let result = if ground_gcd == 1 {
315                        h
316                    } else {
317                        Expression::mul(vec![Expression::integer(ground_gcd), h]).simplify()
318                    };
319                    return Ok((result, cff, cfg));
320                }
321            }
322
323            // Try cofactor interpolation
324            let cff = univariate_interpolate(cff_int, x, var);
325            if let Some(h2) = try_exact_division(&f, &cff, var) {
326                if let Some(cfg) = try_exact_division(&g, &h2, var) {
327                    let result = if ground_gcd == 1 {
328                        h2
329                    } else {
330                        Expression::mul(vec![Expression::integer(ground_gcd), h2]).simplify()
331                    };
332                    return Ok((result, cff, cfg));
333                }
334            }
335
336            let cfg = univariate_interpolate(cfg_int, x, var);
337            if let Some(h3) = try_exact_division(&g, &cfg, var) {
338                if let Some(cff2) = try_exact_division(&f, &h3, var) {
339                    let result = if ground_gcd == 1 {
340                        h3
341                    } else {
342                        Expression::mul(vec![Expression::integer(ground_gcd), h3]).simplify()
343                    };
344                    return Ok((result, cff2, cfg));
345                }
346            }
347        }
348
349        x = update_eval_point(x);
350    }
351
352    Err(HeuristicGCDFailed)
353}
354
355/// Extract ground GCD (numeric content only) from two polynomials
356///
357/// Returns (gcd, poly1/gcd, poly2/gcd) where gcd is the integer GCD
358/// of all numeric coefficients in both polynomials.
359///
360/// CRITICAL: This is NON-RECURSIVE - only extracts numeric content!
361fn extract_ground_gcd(poly1: &Expression, poly2: &Expression) -> (i64, Expression, Expression) {
362    let coeffs1 = collect_numeric_coefficients(poly1);
363    let coeffs2 = collect_numeric_coefficients(poly2);
364
365    // Compute GCD of all coefficients
366    let mut gcd = 0i64;
367    for c in coeffs1.iter().chain(coeffs2.iter()) {
368        gcd = gcd_integers(gcd, c.abs());
369        if gcd == 1 {
370            break;
371        }
372    }
373
374    if gcd <= 1 {
375        return (1, poly1.clone(), poly2.clone());
376    }
377
378    // Divide both polynomials by ground GCD
379    let f = divide_by_integer(poly1, gcd);
380    let g = divide_by_integer(poly2, gcd);
381
382    (gcd, f, g)
383}
384
385/// Collect all numeric coefficients from a polynomial expression
386fn collect_numeric_coefficients(expr: &Expression) -> Vec<i64> {
387    let mut coeffs = Vec::new();
388    collect_coeffs_recursive(expr, &mut coeffs);
389    if coeffs.is_empty() {
390        coeffs.push(1);
391    }
392    coeffs
393}
394
395fn collect_coeffs_recursive(expr: &Expression, coeffs: &mut Vec<i64>) {
396    match expr {
397        Expression::Number(Number::Integer(n)) => {
398            coeffs.push(*n);
399        }
400        Expression::Number(Number::Rational(r)) => {
401            if let Some(n) = r.numer().to_i64() {
402                coeffs.push(n);
403            }
404            if let Some(d) = r.denom().to_i64() {
405                coeffs.push(d);
406            }
407        }
408        Expression::Add(terms) => {
409            for term in terms.iter() {
410                collect_coeffs_recursive(term, coeffs);
411            }
412        }
413        Expression::Mul(factors) => {
414            for factor in factors.iter() {
415                collect_coeffs_recursive(factor, coeffs);
416            }
417        }
418        Expression::Pow(base, exp) => {
419            collect_coeffs_recursive(base, coeffs);
420            collect_coeffs_recursive(exp, coeffs);
421        }
422        _ => {}
423    }
424}
425
426/// Divide polynomial by an integer
427fn divide_by_integer(expr: &Expression, divisor: i64) -> Expression {
428    if divisor == 1 {
429        return expr.clone();
430    }
431
432    match expr {
433        Expression::Number(Number::Integer(n)) => {
434            if n % divisor == 0 {
435                Expression::integer(n / divisor)
436            } else {
437                expr.clone()
438            }
439        }
440        Expression::Add(terms) => {
441            let new_terms: Vec<Expression> = terms
442                .iter()
443                .map(|t| divide_by_integer(t, divisor))
444                .collect();
445            Expression::add(new_terms).simplify()
446        }
447        Expression::Mul(factors) => {
448            // Try to divide the first numeric factor
449            let mut divided = false;
450            let mut new_factors = Vec::new();
451            for factor in factors.iter() {
452                if !divided {
453                    if let Expression::Number(Number::Integer(n)) = factor {
454                        if n % divisor == 0 {
455                            let new_coeff = n / divisor;
456                            if new_coeff != 1 {
457                                new_factors.push(Expression::integer(new_coeff));
458                            }
459                            divided = true;
460                            continue;
461                        }
462                    }
463                }
464                new_factors.push(factor.clone());
465            }
466            if new_factors.is_empty() {
467                Expression::integer(1)
468            } else {
469                Expression::mul(new_factors).simplify()
470            }
471        }
472        _ => expr.clone(),
473    }
474}
475
476/// Compute maximum absolute value of coefficients (infinity norm)
477fn polynomial_max_norm(expr: &Expression, _vars: &[Symbol]) -> i64 {
478    let coeffs = collect_numeric_coefficients(expr);
479    coeffs.iter().map(|c| c.abs()).max().unwrap_or(1)
480}
481
482/// Compute initial evaluation point (SymPy formula)
483fn compute_initial_eval_point(b: i64, f: &Expression, g: &Expression, var: &Symbol) -> i64 {
484    let f_norm = polynomial_max_norm(f, std::slice::from_ref(var));
485    let g_norm = polynomial_max_norm(g, std::slice::from_ref(var));
486
487    let lc_f = leading_coeff_abs(f, var);
488    let lc_g = leading_coeff_abs(g, var);
489
490    let sqrt_b = (b as f64).sqrt() as i64;
491    let option1 = b.min(99 * sqrt_b);
492    let option2 = if lc_f > 0 && lc_g > 0 {
493        2 * (f_norm / lc_f).min(g_norm / lc_g) + 4
494    } else {
495        4
496    };
497
498    option1.max(option2).max(2)
499}
500
/// Next evaluation point after a failed attempt (SymPy's update formula):
/// `x' = 73794 · x · ⌊√⌊√x⌋⌋ / 27011`.
///
/// The product is formed in `i128`. In `i64` it overflows once `x` reaches about
/// 2·10^11, which panics in debug builds and yields a garbage point in release
/// builds. A result that does not fit in `i64` saturates to `i64::MAX`.
fn update_eval_point(x: i64) -> i64 {
    let root = (x as f64).sqrt() as i64;
    let fourth_root = (root as f64).sqrt() as i64;
    let next = 73_794_i128 * i128::from(x) * i128::from(fourth_root) / 27_011;
    i64::try_from(next).unwrap_or(i64::MAX)
}
507
508/// Get absolute value of leading coefficient
509fn leading_coeff_abs(expr: &Expression, var: &Symbol) -> i64 {
510    let lc = expr.polynomial_leading_coefficient(var);
511    match lc {
512        Expression::Number(Number::Integer(n)) => n.abs(),
513        _ => 1,
514    }
515}
516
517/// Evaluate polynomial at integer value for main variable
518///
519/// Substitutes var = value and simplifies to get a polynomial in remaining variables.
520pub fn polynomial_evaluate_at(poly: &Expression, var: &Symbol, value: i64) -> Expression {
521    substitute_var(poly, var, value).simplify()
522}
523
524fn substitute_var(expr: &Expression, var: &Symbol, value: i64) -> Expression {
525    match expr {
526        Expression::Symbol(s) if s == var => Expression::integer(value),
527        Expression::Symbol(_) => expr.clone(),
528        Expression::Number(_) => expr.clone(),
529        Expression::Constant(_) => expr.clone(),
530        Expression::Add(terms) => {
531            let new_terms: Vec<Expression> = terms
532                .iter()
533                .map(|t| substitute_var(t, var, value))
534                .collect();
535            Expression::add(new_terms)
536        }
537        Expression::Mul(factors) => {
538            let new_factors: Vec<Expression> = factors
539                .iter()
540                .map(|f| substitute_var(f, var, value))
541                .collect();
542            Expression::mul(new_factors)
543        }
544        Expression::Pow(base, exp) => {
545            let new_base = substitute_var(base, var, value);
546            let new_exp = substitute_var(exp, var, value);
547            Expression::pow(new_base, new_exp)
548        }
549        _ => expr.clone(),
550    }
551}
552
553/// Evaluate univariate polynomial at integer value, returning integer result
554fn evaluate_univariate(poly: &Expression, var: &Symbol, value: i64) -> i64 {
555    let result = polynomial_evaluate_at(poly, var, value);
556    match result {
557        Expression::Number(Number::Integer(n)) => n,
558        _ => 0,
559    }
560}
561
562/// Interpolate polynomial from integer using symmetric representation
563///
564/// Recovers polynomial coefficients from integer h using base x.
565/// Uses symmetric modular representation: if coeff > x/2, use coeff - x.
566pub fn polynomial_interpolate(h: &Expression, x: i64, var: &Symbol) -> Expression {
567    // If h is already a polynomial (multivariate case), apply interpolation recursively
568    match h {
569        Expression::Number(Number::Integer(n)) => univariate_interpolate(*n, x, var),
570        Expression::Add(terms) => {
571            let new_terms: Vec<Expression> = terms
572                .iter()
573                .map(|t| polynomial_interpolate(t, x, var))
574                .collect();
575            Expression::add(new_terms).simplify()
576        }
577        Expression::Mul(factors) => {
578            let new_factors: Vec<Expression> = factors
579                .iter()
580                .map(|f| polynomial_interpolate(f, x, var))
581                .collect();
582            Expression::mul(new_factors).simplify()
583        }
584        _ => h.clone(),
585    }
586}
587
588/// Interpolate univariate polynomial from integer
589///
590/// Converts integer h to polynomial using base x with symmetric representation.
591/// The resulting polynomial uses `var` as the variable.
592fn univariate_interpolate(mut h: i64, x: i64, var: &Symbol) -> Expression {
593    if h == 0 {
594        return expr!(0);
595    }
596
597    let mut coeffs = Vec::new();
598    let half_x = x / 2;
599
600    while h != 0 {
601        let mut coeff = h % x;
602        if coeff > half_x {
603            coeff -= x;
604        }
605        coeffs.push(coeff);
606        h = (h - coeff) / x;
607    }
608
609    if coeffs.is_empty() {
610        return expr!(0);
611    }
612
613    let mut terms = Vec::new();
614
615    for (power, &coeff) in coeffs.iter().enumerate() {
616        if coeff == 0 {
617            continue;
618        }
619        let term = if power == 0 {
620            Expression::integer(coeff)
621        } else if power == 1 {
622            if coeff == 1 {
623                Expression::symbol(var.clone())
624            } else {
625                Expression::mul(vec![
626                    Expression::integer(coeff),
627                    Expression::symbol(var.clone()),
628                ])
629            }
630        } else {
631            let power_expr = Expression::pow(
632                Expression::symbol(var.clone()),
633                Expression::integer(power as i64),
634            );
635            if coeff == 1 {
636                power_expr
637            } else {
638                Expression::mul(vec![Expression::integer(coeff), power_expr])
639            }
640        };
641        terms.push(term);
642    }
643
644    if terms.is_empty() {
645        expr!(0)
646    } else if terms.len() == 1 {
647        terms.pop().unwrap()
648    } else {
649        Expression::add(terms)
650    }
651}
652
653/// Compute primitive part (divide by content) for multivariate polynomial
654fn ground_primitive(poly: &Expression, _vars: &[Symbol]) -> Expression {
655    let coeffs = collect_numeric_coefficients(poly);
656    let mut content = 0i64;
657    for c in coeffs {
658        content = gcd_integers(content, c.abs());
659        if content == 1 {
660            break;
661        }
662    }
663
664    if content <= 1 {
665        return poly.clone();
666    }
667
668    divide_by_integer(poly, content)
669}
670
671/// Compute primitive part for univariate polynomial
672fn univariate_primitive(poly: &Expression, _var: &Symbol) -> Expression {
673    let coeffs = collect_numeric_coefficients(poly);
674    let mut content = 0i64;
675    for c in coeffs {
676        content = gcd_integers(content, c.abs());
677        if content == 1 {
678            break;
679        }
680    }
681
682    if content <= 1 {
683        return poly.clone();
684    }
685
686    divide_by_integer(poly, content)
687}
688
689/// Try exact polynomial division, returning Some(quotient) if successful
690fn try_exact_division(
691    dividend: &Expression,
692    divisor: &Expression,
693    _var: &Symbol,
694) -> Option<Expression> {
695    if divisor.is_one() {
696        return Some(dividend.clone());
697    }
698    if divisor.is_zero() {
699        return None;
700    }
701    if dividend == divisor {
702        return Some(expr!(1));
703    }
704    if dividend.is_zero() {
705        return Some(expr!(0));
706    }
707
708    // Try integer division first
709    if let (Expression::Number(Number::Integer(a)), Expression::Number(Number::Integer(b))) =
710        (dividend, divisor)
711    {
712        if *b != 0 && a % b == 0 {
713            return Some(Expression::integer(a / b));
714        } else {
715            return None;
716        }
717    }
718
719    // Try monomial division: coeff * vars / vars -> coeff (e.g., 2*y / y = 2)
720    if let Some(result) = try_monomial_division(dividend, divisor) {
721        return Some(result);
722    }
723
724    // Try sum division: (a + b) / c = a/c + b/c (if both exact)
725    if let Expression::Add(terms) = dividend {
726        let mut quotient_terms = Vec::new();
727        for term in terms.iter() {
728            if let Some(q) = try_exact_division(term, divisor, _var) {
729                quotient_terms.push(q);
730            } else {
731                return None;
732            }
733        }
734        let result = Expression::add(quotient_terms).simplify();
735        return Some(result);
736    }
737
738    None
739}
740
741/// Try to divide monomials directly
742/// Handles cases like: 2*y / y = 2, 6*x*y / (2*x) = 3*y
743fn try_monomial_division(dividend: &Expression, divisor: &Expression) -> Option<Expression> {
744    let (div_coeff, div_vars) = extract_coeff_and_vars(dividend);
745    let (sor_coeff, sor_vars) = extract_coeff_and_vars(divisor);
746
747    // Check coefficient divisibility
748    if sor_coeff == 0 {
749        return None;
750    }
751    if div_coeff % sor_coeff != 0 {
752        return None;
753    }
754    let result_coeff = div_coeff / sor_coeff;
755
756    // Check variable divisibility: each var in divisor must appear in dividend with >= power
757    let mut remaining_vars = div_vars;
758    for (var, power) in &sor_vars {
759        if let Some(div_power) = remaining_vars.get(var) {
760            if *div_power >= *power {
761                let new_power = div_power - power;
762                if new_power == 0 {
763                    remaining_vars.remove(var);
764                } else {
765                    remaining_vars.insert(var.clone(), new_power);
766                }
767            } else {
768                return None;
769            }
770        } else {
771            return None;
772        }
773    }
774
775    // Build result expression
776    let mut factors = Vec::new();
777    if result_coeff != 1 || remaining_vars.is_empty() {
778        factors.push(Expression::integer(result_coeff));
779    }
780    for (var, power) in remaining_vars {
781        if power == 1 {
782            factors.push(Expression::symbol(var));
783        } else {
784            factors.push(Expression::pow(
785                Expression::symbol(var),
786                Expression::integer(power),
787            ));
788        }
789    }
790
791    if factors.is_empty() {
792        Some(Expression::integer(1))
793    } else if factors.len() == 1 {
794        Some(factors.pop().unwrap())
795    } else {
796        Some(Expression::mul(factors).simplify())
797    }
798}
799
800/// Extract numeric coefficient and variable factors from a monomial expression
801/// Returns (coefficient, HashMap of variable -> power)
802fn extract_coeff_and_vars(expr: &Expression) -> (i64, std::collections::HashMap<Symbol, i64>) {
803    let mut coeff = 1i64;
804    let mut vars = std::collections::HashMap::new();
805
806    extract_coeff_and_vars_recursive(expr, &mut coeff, &mut vars);
807
808    (coeff, vars)
809}
810
811fn extract_coeff_and_vars_recursive(
812    expr: &Expression,
813    coeff: &mut i64,
814    vars: &mut std::collections::HashMap<Symbol, i64>,
815) {
816    match expr {
817        Expression::Number(Number::Integer(n)) => {
818            *coeff *= *n;
819        }
820        Expression::Symbol(s) => {
821            *vars.entry(s.clone()).or_insert(0) += 1;
822        }
823        Expression::Pow(base, exp) => {
824            if let (Expression::Symbol(s), Expression::Number(Number::Integer(e))) =
825                (base.as_ref(), exp.as_ref())
826            {
827                *vars.entry(s.clone()).or_insert(0) += *e;
828            }
829        }
830        Expression::Mul(factors) => {
831            for factor in factors.iter() {
832                extract_coeff_and_vars_recursive(factor, coeff, vars);
833            }
834        }
835        _ => {}
836    }
837}
838
839/// Divide exactly (assuming division is exact)
840fn divide_exact(dividend: &Expression, divisor: &Expression) -> Expression {
841    if divisor.is_one() {
842        return dividend.clone();
843    }
844    if divisor.is_zero() {
845        return dividend.clone();
846    }
847    if dividend == divisor {
848        return expr!(1);
849    }
850
851    if let (Expression::Number(Number::Integer(a)), Expression::Number(Number::Integer(b))) =
852        (dividend, divisor)
853    {
854        if *b != 0 && a % b == 0 {
855            return Expression::integer(a / b);
856        }
857    }
858
859    Expression::mul(vec![
860        dividend.clone(),
861        Expression::pow(divisor.clone(), Expression::integer(-1)),
862    ])
863    .simplify()
864}
865
/// Non-negative GCD of two integers via the Euclidean algorithm.
///
/// Signs are ignored. By this module's convention `gcd_integers(0, 0)` is `1`,
/// never `0`, so the result is always safe to divide by.
fn gcd_integers(a: i64, b: i64) -> i64 {
    fn euclid(x: i64, y: i64) -> i64 {
        if y == 0 {
            x
        } else {
            euclid(y, x % y)
        }
    }
    euclid(a.abs(), b.abs()).max(1)
}
877
878/// Check if expression depends on a variable
879fn depends_on_var(expr: &Expression, var: &Symbol) -> bool {
880    match expr {
881        Expression::Symbol(s) => s == var,
882        Expression::Add(terms) | Expression::Mul(terms) => {
883            terms.iter().any(|t| depends_on_var(t, var))
884        }
885        Expression::Pow(base, exp) => depends_on_var(base, var) || depends_on_var(exp, var),
886        _ => false,
887    }
888}
889
890/// Check if expression is constant with respect to variables
891fn is_constant(expr: &Expression, vars: &[Symbol]) -> bool {
892    !vars.iter().any(|v| depends_on_var(expr, v))
893}
894
895/// Collect all symbols (free variables) from an expression
896fn collect_expression_symbols(expr: &Expression) -> std::collections::HashSet<Symbol> {
897    use std::collections::HashSet;
898    let mut symbols = HashSet::new();
899
900    fn collect(expr: &Expression, symbols: &mut HashSet<Symbol>) {
901        match expr {
902            Expression::Symbol(s) => {
903                symbols.insert(s.clone());
904            }
905            Expression::Add(terms) | Expression::Mul(terms) => {
906                for term in terms.iter() {
907                    collect(term, symbols);
908                }
909            }
910            Expression::Pow(base, exp) => {
911                collect(base, symbols);
912                collect(exp, symbols);
913            }
914            Expression::Function { args, .. } => {
915                for arg in args.iter() {
916                    collect(arg, symbols);
917                }
918            }
919            _ => {}
920        }
921    }
922
923    collect(expr, &mut symbols);
924    symbols
925}
926
927/// Check if two polynomials have disjoint variable sets
928/// If they share no common variables, their GCD is 1 (coprime)
929fn have_disjoint_variables(poly1: &Expression, poly2: &Expression) -> bool {
930    let vars1 = collect_expression_symbols(poly1);
931    let vars2 = collect_expression_symbols(poly2);
932    vars1.is_disjoint(&vars2)
933}
934
935/// Fallback: Euclidean GCD for univariate polynomials
936fn univariate_gcd_euclidean(poly1: &Expression, poly2: &Expression, var: &Symbol) -> Expression {
937    if poly1.is_one() || poly2.is_one() {
938        return expr!(1);
939    }
940    if !depends_on_var(poly1, var) && !depends_on_var(poly2, var) {
941        return poly1.gcd(poly2);
942    }
943
944    let mut a = poly1.clone();
945    let mut b = poly2.clone();
946
947    while !b.is_zero() {
948        let r = a.rem_polynomial(&b, var);
949        if r == a {
950            return expr!(1);
951        }
952        a = b;
953        b = r;
954    }
955
956    normalize_gcd(&a, var)
957}
958
959/// Normalize GCD to have positive leading coefficient
960fn normalize_gcd(poly: &Expression, var: &Symbol) -> Expression {
961    let lc = poly.polynomial_leading_coefficient(var);
962    if let Expression::Number(Number::Integer(n)) = lc {
963        if n < 0 {
964            return Expression::mul(vec![Expression::integer(-1), poly.clone()]).simplify();
965        }
966    }
967    poly.clone()
968}
969
#[cfg(test)]
mod tests {
    //! Unit tests for the integer helpers, evaluation/interpolation helpers, and the
    //! `multivariate_gcd` entry point.

    use super::*;
    use crate::{expr, symbol};

    /// Variable list `[x, y]` shared by the bivariate tests.
    fn xy_vars() -> Vec<Symbol> {
        vec![symbol!(x), symbol!(y)]
    }

    /// The monomial `x * y`.
    fn xy_product() -> Expression {
        Expression::mul(vec![expr!(x), expr!(y)])
    }

    #[test]
    fn test_gcd_integers() {
        let cases = [
            ((12, 18), 6),
            ((15, 25), 5),
            ((7, 11), 1),
            ((0, 5), 5),
            ((-12, 18), 6),
        ];
        for ((lhs, rhs), expected) in cases {
            assert_eq!(gcd_integers(lhs, rhs), expected, "gcd_integers({lhs}, {rhs})");
        }
    }

    #[test]
    fn test_extract_ground_gcd() {
        let six_x = Expression::mul(vec![Expression::integer(6), expr!(x)]);
        let nine_y = Expression::mul(vec![Expression::integer(9), expr!(y)]);
        let (content, _, _) = extract_ground_gcd(&six_x, &nine_y);
        assert_eq!(content, 3);
    }

    #[test]
    fn test_polynomial_evaluate_at() {
        // x^2 + 2x + 1 evaluated at x = 3 is (3 + 1)^2 = 16.
        let square = Expression::add(vec![
            Expression::pow(expr!(x), Expression::integer(2)),
            Expression::mul(vec![Expression::integer(2), expr!(x)]),
            Expression::integer(1),
        ]);
        assert_eq!(
            polynomial_evaluate_at(&square, &symbol!(x), 3),
            Expression::integer(16)
        );
    }

    #[test]
    fn test_trivial_cases() {
        let vars = xy_vars();
        let px = expr!(x);

        assert_eq!(multivariate_gcd(&expr!(0), &px, &vars), px);
        assert_eq!(multivariate_gcd(&px, &px, &vars), px);
        assert_eq!(multivariate_gcd(&expr!(1), &px, &vars), expr!(1));
    }

    #[test]
    fn test_constant_gcd() {
        let gcd = multivariate_gcd(&Expression::integer(6), &Expression::integer(9), &xy_vars());
        assert_eq!(gcd, Expression::integer(3));
    }

    #[test]
    fn test_coprime_polynomials() {
        let gcd = multivariate_gcd(&expr!(x), &expr!(y), &xy_vars());
        assert!(gcd.is_one() || gcd == Expression::integer(1));
    }

    #[test]
    fn test_bivariate_gcd_simple() {
        let product = xy_product();
        assert_eq!(multivariate_gcd(&product, &product, &xy_vars()), product);
    }

    #[test]
    fn test_bivariate_gcd_content() {
        let two_xy = Expression::mul(vec![Expression::integer(2), expr!(x), expr!(y)]);
        let three_xy = Expression::mul(vec![Expression::integer(3), expr!(x), expr!(y)]);
        assert!(!multivariate_gcd(&two_xy, &three_xy, &xy_vars()).is_zero());
    }

    #[test]
    fn test_bivariate_gcd_different_degrees() {
        let x2_y = Expression::mul(vec![expr!(x ^ 2), expr!(y)]);
        let x_y2 = Expression::mul(vec![expr!(x), expr!(y ^ 2)]);
        assert!(!multivariate_gcd(&x2_y, &x_y2, &xy_vars()).is_zero());
    }

    #[test]
    fn test_trivariate_gcd() {
        let vars = vec![symbol!(x), symbol!(y), symbol!(z)];
        let xyz = Expression::mul(vec![expr!(x), expr!(y), expr!(z)]);
        assert!(!multivariate_gcd(&xyz, &xy_product(), &vars).is_zero());
    }

    #[test]
    fn test_univariate_interpolate() {
        assert_eq!(
            univariate_interpolate(5, 10, &symbol!(x)),
            Expression::integer(5)
        );
    }

    #[test]
    fn test_zero_polynomial() {
        let vars = xy_vars();
        let product = xy_product();
        let zero = expr!(0);

        assert_eq!(multivariate_gcd(&product, &zero, &vars), product);
        assert_eq!(multivariate_gcd(&zero, &product, &vars), product);
    }
}