Skip to main content

alkahest_cas/poly/
multipoly.rs

1use super::error::ConversionError;
2use crate::flint::mpoly::{FlintMPoly, FlintMPolyCtx};
3use crate::kernel::{ExprData, ExprId, ExprPool};
4use std::collections::BTreeMap;
5use std::fmt;
6use std::ops::{Add, Mul, Neg, Sub};
7
8// ---------------------------------------------------------------------------
9// Exponent vector: ascending by variable index.
10// Invariant: trailing zeros are stripped so that the zero polynomial has no
11// terms and the constant 1 has a single entry with key vec![].
12// ---------------------------------------------------------------------------
13
/// Exponent vector of a single monomial: entry `i` is the power of variable `i`.
/// Trailing zeros are stripped, so the constant monomial is the empty vector.
type Exponents = Vec<u32>;
/// Sparse term storage mapping each exponent vector to its integer coefficient.
/// The helpers in this file remove entries whose coefficient becomes zero.
type TermMap = BTreeMap<Exponents, rug::Integer>;
16
17fn termmap_add(mut a: TermMap, b: TermMap) -> TermMap {
18    for (exp, coeff) in b {
19        let entry = a
20            .entry(exp.clone())
21            .or_insert_with(|| rug::Integer::from(0));
22        *entry += coeff;
23        if *entry == 0 {
24            a.remove(&exp);
25        }
26    }
27    a
28}
29
30fn termmap_mul(a: &TermMap, b: &TermMap) -> TermMap {
31    let mut result = TermMap::new();
32    for (ea, ca) in a {
33        for (eb, cb) in b {
34            let prod = ca.clone() * cb.clone();
35            if prod == 0 {
36                continue;
37            }
38            let len = ea.len().max(eb.len());
39            let mut exp = vec![0u32; len];
40            for (i, &e) in ea.iter().enumerate() {
41                exp[i] += e;
42            }
43            for (i, &e) in eb.iter().enumerate() {
44                exp[i] += e;
45            }
46            // strip trailing zeros
47            while exp.last() == Some(&0) {
48                exp.pop();
49            }
50            let entry = result
51                .entry(exp.clone())
52                .or_insert_with(|| rug::Integer::from(0));
53            *entry += prod;
54            if *entry == 0 {
55                result.remove(&exp);
56            }
57        }
58    }
59    result
60}
61
62fn termmap_neg(map: TermMap) -> TermMap {
63    map.into_iter().map(|(k, v)| (k, -v)).collect()
64}
65
66fn termmap_pow(base: &TermMap, n: u32) -> TermMap {
67    if n == 0 {
68        let mut one = TermMap::new();
69        one.insert(vec![], rug::Integer::from(1));
70        return one;
71    }
72    if n == 1 {
73        return base.clone();
74    }
75    let half = termmap_pow(base, n / 2);
76    let mut result = termmap_mul(&half, &half);
77    if n % 2 == 1 {
78        result = termmap_mul(&result, base);
79    }
80    result
81}
82
83fn expr_to_multivariate_coeffs(
84    expr: ExprId,
85    vars: &[ExprId],
86    pool: &ExprPool,
87) -> Result<TermMap, ConversionError> {
88    // Extract node data in a single lock acquisition, then release before recursing.
89    enum NodeInfo {
90        Symbol { idx: Option<usize>, name: String },
91        Integer(rug::Integer),
92        NonIntCoeff,
93        Add(Vec<ExprId>),
94        Mul(Vec<ExprId>),
95        Pow { base: ExprId, exp: ExprId },
96        Func(String),
97    }
98
99    let info = pool.with(expr, |data| match data {
100        ExprData::Symbol { name, .. } => NodeInfo::Symbol {
101            idx: vars.iter().position(|&v| v == expr),
102            name: name.clone(),
103        },
104        ExprData::Integer(n) => NodeInfo::Integer(n.0.clone()),
105        ExprData::Rational(_) | ExprData::Float(_) => NodeInfo::NonIntCoeff,
106        ExprData::Add(args) => NodeInfo::Add(args.clone()),
107        ExprData::Mul(args) => NodeInfo::Mul(args.clone()),
108        ExprData::Pow { base, exp } => NodeInfo::Pow {
109            base: *base,
110            exp: *exp,
111        },
112        ExprData::Func { name, .. } => NodeInfo::Func(name.clone()),
113        ExprData::Piecewise { .. }
114        | ExprData::Predicate { .. }
115        | ExprData::Forall { .. }
116        | ExprData::Exists { .. }
117        | ExprData::BigO(_) => NodeInfo::Func("piecewise_or_predicate".to_string()),
118    });
119
120    match info {
121        NodeInfo::Symbol { idx: Some(idx), .. } => {
122            let mut exp = vec![0u32; idx + 1];
123            exp[idx] = 1;
124            let mut map = TermMap::new();
125            map.insert(exp, rug::Integer::from(1));
126            Ok(map)
127        }
128        NodeInfo::Symbol { name, .. } => Err(ConversionError::UnexpectedSymbol(name)),
129        NodeInfo::Integer(n) => {
130            let mut map = TermMap::new();
131            if n != 0 {
132                map.insert(vec![], n);
133            }
134            Ok(map)
135        }
136        NodeInfo::NonIntCoeff => Err(ConversionError::NonIntegerCoefficient),
137        NodeInfo::Add(args) => {
138            let mut acc = TermMap::new();
139            for arg in args {
140                let sub = expr_to_multivariate_coeffs(arg, vars, pool)?;
141                acc = termmap_add(acc, sub);
142            }
143            Ok(acc)
144        }
145        NodeInfo::Mul(args) => {
146            let mut acc: TermMap = {
147                let mut m = TermMap::new();
148                m.insert(vec![], rug::Integer::from(1));
149                m
150            };
151            for arg in args {
152                let sub = expr_to_multivariate_coeffs(arg, vars, pool)?;
153                acc = termmap_mul(&acc, &sub);
154            }
155            Ok(acc)
156        }
157        NodeInfo::Pow { base, exp } => {
158            // Read the exponent without holding the pool lock during recursion.
159            let n = pool
160                .with(exp, |data| match data {
161                    ExprData::Integer(n) => Some(n.0.clone()),
162                    _ => None,
163                })
164                .ok_or(ConversionError::NonConstantExponent)?;
165            if n < 0 {
166                return Err(ConversionError::NegativeExponent);
167            }
168            let n_u32 = n.to_u32().ok_or(ConversionError::ExponentTooLarge)?;
169            let base_coeffs = expr_to_multivariate_coeffs(base, vars, pool)?;
170            Ok(termmap_pow(&base_coeffs, n_u32))
171        }
172        NodeInfo::Func(name) => Err(ConversionError::NonPolynomialFunction(name)),
173    }
174}
175
176// ---------------------------------------------------------------------------
177// MultiPoly
178// ---------------------------------------------------------------------------
179
/// Sparse multivariate polynomial over ℤ.
///
/// `vars` fixes the variable ordering; the exponent key `[e0, e1, …]` means
/// `vars[0]^e0 * vars[1]^e1 * …`.  Trailing zeros in the exponent vector are
/// always stripped so structural equality reduces to map equality.
///
/// Both fields are public; code that builds or mutates `terms` directly is
/// responsible for keeping the keys free of trailing zeros.
#[derive(Clone, PartialEq, Eq)]
pub struct MultiPoly {
    /// Ordered variable list; index `i` here corresponds to exponent index `i`.
    pub vars: Vec<ExprId>,
    /// Exponent vector → coefficient map; the zero polynomial has no entries.
    pub terms: TermMap,
}
190
191impl MultiPoly {
192    pub fn zero(vars: Vec<ExprId>) -> Self {
193        MultiPoly {
194            vars,
195            terms: TermMap::new(),
196        }
197    }
198
199    pub fn constant(vars: Vec<ExprId>, c: i64) -> Self {
200        let mut terms = TermMap::new();
201        if c != 0 {
202            terms.insert(vec![], rug::Integer::from(c));
203        }
204        MultiPoly { vars, terms }
205    }
206
207    pub fn from_symbolic(
208        expr: ExprId,
209        vars: Vec<ExprId>,
210        pool: &ExprPool,
211    ) -> Result<Self, ConversionError> {
212        let terms = expr_to_multivariate_coeffs(expr, &vars, pool)?;
213        Ok(MultiPoly { vars, terms })
214    }
215
216    pub fn is_zero(&self) -> bool {
217        self.terms.is_empty()
218    }
219
220    pub fn total_degree(&self) -> u32 {
221        self.terms
222            .keys()
223            .map(|exp| exp.iter().sum::<u32>())
224            .max()
225            .unwrap_or(0)
226    }
227
228    /// GCD of all integer coefficients (content). Returns 0 for the zero polynomial.
229    pub fn integer_content(&self) -> rug::Integer {
230        self.terms.values().fold(rug::Integer::from(0), |acc, c| {
231            rug::Integer::from(acc.gcd_ref(c))
232        })
233    }
234
235    /// Primitive part: divide all coefficients by the integer content.
236    pub fn primitive_part(&self) -> Self {
237        let g = self.integer_content();
238        if g == 0 {
239            return self.clone();
240        }
241        self.div_integer(&g)
242    }
243
244    /// Returns `true` if both polynomials have the same variable list and can be combined.
245    pub fn compatible_with(&self, other: &Self) -> bool {
246        self.vars == other.vars
247    }
248
249    /// Compute the GCD of two compatible multivariate polynomials using FLINT.
250    ///
251    /// Returns `None` if the polynomials have different variable lists, if either
252    /// is zero, or if FLINT's GCD algorithm fails (which is exceedingly rare).
253    ///
254    /// The returned GCD is normalised so that its leading coefficient is positive.
255    pub fn gcd(&self, other: &Self) -> Option<Self> {
256        if !self.compatible_with(other) {
257            return None;
258        }
259        if self.is_zero() || other.is_zero() {
260            return None;
261        }
262
263        let nvars = self.vars.len();
264
265        // Build a FLINT context and convert both polynomials.
266        let ctx = FlintMPolyCtx::new(nvars.max(1));
267
268        let a = multi_to_flint(self, &ctx);
269        let b = multi_to_flint(other, &ctx);
270
271        let g = a.gcd(&b, &ctx)?;
272
273        // Convert the GCD back to MultiPoly
274        let terms = g.terms(nvars.max(1), &ctx);
275        let mut gcd = MultiPoly {
276            vars: self.vars.clone(),
277            terms,
278        };
279
280        // Normalise: make the leading coefficient positive
281        if let Some((_, lc)) = gcd.terms.iter().next_back() {
282            if *lc < 0 {
283                gcd = -gcd;
284            }
285        }
286
287        Some(gcd)
288    }
289
290    /// Convert back to a symbolic expression in the given pool.
291    ///
292    /// Produces a canonical sum-of-products: each term is `coeff * var[0]^e0 * var[1]^e1 * …`.
293    /// The zero polynomial maps to `Integer(0)`.
294    pub fn to_expr(&self, pool: &ExprPool) -> ExprId {
295        if self.terms.is_empty() {
296            return pool.integer(0_i32);
297        }
298        let summands: Vec<ExprId> = self
299            .terms
300            .iter()
301            .map(|(exps, coeff)| {
302                let coeff_id = pool.integer(coeff.clone());
303                let mut factors = vec![coeff_id];
304                for (i, &e) in exps.iter().enumerate() {
305                    if e == 0 || i >= self.vars.len() {
306                        continue;
307                    }
308                    let var = self.vars[i];
309                    let exp_id = pool.integer(e);
310                    factors.push(if e == 1 { var } else { pool.pow(var, exp_id) });
311                }
312                match factors.len() {
313                    0 => pool.integer(1_i32),
314                    1 => factors[0],
315                    _ => pool.mul(factors),
316                }
317            })
318            .collect();
319
320        match summands.len() {
321            0 => pool.integer(0_i32),
322            1 => summands[0],
323            _ => pool.add(summands),
324        }
325    }
326
327    /// Divide all coefficients by `d` (exact division — caller ensures divisibility).
328    pub fn div_integer(&self, d: &rug::Integer) -> Self {
329        debug_assert!(
330            self.terms.values().all(|v| v.is_divisible(d)),
331            "div_integer: not all coefficients are divisible by {d}"
332        );
333        let terms = self
334            .terms
335            .iter()
336            .map(|(k, v)| (k.clone(), rug::Integer::from(v.div_exact_ref(d))))
337            .collect();
338        MultiPoly {
339            vars: self.vars.clone(),
340            terms,
341        }
342    }
343}
344
/// Convert a `MultiPoly` to a `FlintMPoly` in the given context.
///
/// Crate-visible wrapper that forwards directly to the private
/// `multi_to_flint` helper.
pub(crate) fn multi_to_flint_pub(p: &MultiPoly, ctx: &FlintMPolyCtx) -> FlintMPoly {
    multi_to_flint(p, ctx)
}
349
350fn multi_to_flint(p: &MultiPoly, ctx: &FlintMPolyCtx) -> FlintMPoly {
351    let nvars = p.vars.len().max(1);
352    let mut fp = FlintMPoly::new(ctx);
353    for (exp, coeff) in &p.terms {
354        let mut exp_u64 = vec![0u64; nvars];
355        for (i, &e) in exp.iter().enumerate() {
356            if i < nvars {
357                exp_u64[i] = e as u64;
358            }
359        }
360        fp.push_term(coeff, &exp_u64, ctx);
361    }
362    fp.finish(ctx);
363    fp
364}
365
366fn same_vars(a: &MultiPoly, b: &MultiPoly) {
367    assert_eq!(
368        a.vars, b.vars,
369        "MultiPoly arithmetic requires both operands to share the same variable list"
370    );
371}
372
373impl Neg for MultiPoly {
374    type Output = Self;
375    fn neg(self) -> Self {
376        MultiPoly {
377            vars: self.vars,
378            terms: termmap_neg(self.terms),
379        }
380    }
381}
382
383impl Add for MultiPoly {
384    type Output = Self;
385    fn add(self, rhs: Self) -> Self {
386        same_vars(&self, &rhs);
387        MultiPoly {
388            vars: self.vars.clone(),
389            terms: termmap_add(self.terms, rhs.terms),
390        }
391    }
392}
393
394impl Sub for MultiPoly {
395    type Output = Self;
396    fn sub(self, rhs: Self) -> Self {
397        same_vars(&self, &rhs);
398        MultiPoly {
399            vars: self.vars.clone(),
400            terms: termmap_add(self.terms, termmap_neg(rhs.terms)),
401        }
402    }
403}
404
405impl Mul for MultiPoly {
406    type Output = Self;
407    fn mul(self, rhs: Self) -> Self {
408        same_vars(&self, &rhs);
409        MultiPoly {
410            vars: self.vars.clone(),
411            terms: termmap_mul(&self.terms, &rhs.terms),
412        }
413    }
414}
415
416impl fmt::Display for MultiPoly {
417    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
418        if self.is_zero() {
419            return write!(f, "0");
420        }
421        let mut first = true;
422        // BTreeMap iterates in lexicographic key order (lowest degree first)
423        for (exp, coeff) in &self.terms {
424            if !first {
425                if *coeff > 0 {
426                    write!(f, " + ")?;
427                } else {
428                    write!(f, " - ")?;
429                }
430            } else if *coeff < 0 {
431                write!(f, "-")?;
432            }
433            first = false;
434
435            let abs_coeff = rug::Integer::from(coeff.abs_ref());
436            let has_vars = exp.iter().any(|&e| e > 0);
437            if abs_coeff != 1 || !has_vars {
438                write!(f, "{abs_coeff}")?;
439            }
440            for (i, &e) in exp.iter().enumerate() {
441                if e == 0 {
442                    continue;
443                }
444                // Use generic xi notation since we don't have ExprPool here
445                let var_label = format!("x{i}");
446                if e == 1 {
447                    write!(f, "{var_label}")?;
448                } else {
449                    write!(f, "{var_label}^{e}")?;
450                }
451            }
452        }
453        Ok(())
454    }
455}
456
457impl fmt::Debug for MultiPoly {
458    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
459        write!(f, "MultiPoly(vars={:?}, {})", self.vars, self)
460    }
461}
462
463// ---------------------------------------------------------------------------
464// Unit tests
465// ---------------------------------------------------------------------------
466
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{Domain, ExprPool};

    /// Fresh pool holding two real symbols, `x` and `y`.
    fn pool_xy() -> (ExprPool, ExprId, ExprId) {
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);
        (pool, x, y)
    }

    /// Shorthand for building an expected coefficient.
    fn int(value: i32) -> rug::Integer {
        rug::Integer::from(value)
    }

    /// Coefficient stored for the given exponent vector, if any.
    fn coeff<'a>(poly: &'a MultiPoly, exps: &[u32]) -> Option<&'a rug::Integer> {
        poly.terms.get(exps)
    }

    #[test]
    fn univariate_from_symbolic() {
        // x^2 + 2x + 1
        let (pool, x, y) = pool_xy();
        let x_squared = pool.pow(x, pool.integer(2_i32));
        let two_x = pool.mul(vec![pool.integer(2_i32), x]);
        let expr = pool.add(vec![x_squared, two_x, pool.integer(1_i32)]);
        let poly = MultiPoly::from_symbolic(expr, vec![x, y], &pool).unwrap();
        // constant, x^1 and x^2 terms, in that order
        assert_eq!(coeff(&poly, &[]), Some(&int(1)));
        assert_eq!(coeff(&poly, &[1]), Some(&int(2)));
        assert_eq!(coeff(&poly, &[2]), Some(&int(1)));
    }

    #[test]
    fn bivariate_from_symbolic() {
        // x*y
        let (pool, x, y) = pool_xy();
        let product = pool.mul(vec![x, y]);
        let poly = MultiPoly::from_symbolic(product, vec![x, y], &pool).unwrap();
        assert_eq!(coeff(&poly, &[1, 1]), Some(&int(1)));
        assert_eq!(poly.terms.len(), 1);
    }

    #[test]
    fn zero_poly() {
        let (_pool, x, y) = pool_xy();
        assert!(MultiPoly::zero(vec![x, y]).is_zero());
    }

    #[test]
    fn add_polys() {
        let (pool, x, y) = pool_xy();
        let lhs = MultiPoly::from_symbolic(x, vec![x, y], &pool).unwrap();
        let rhs = MultiPoly::from_symbolic(y, vec![x, y], &pool).unwrap();
        let sum = lhs + rhs;
        assert_eq!(coeff(&sum, &[1]), Some(&int(1))); // x
        assert_eq!(coeff(&sum, &[0, 1]), Some(&int(1))); // y
    }

    #[test]
    fn mul_polys() {
        // (x + 1) * (x - 1) = x^2 - 1
        let (pool, x, y) = pool_xy();
        let x_plus_one = pool.add(vec![x, pool.integer(1_i32)]);
        let lhs = MultiPoly::from_symbolic(x_plus_one, vec![x, y], &pool).unwrap();
        let x_minus_one = pool.add(vec![x, pool.integer(-1_i32)]);
        let rhs = MultiPoly::from_symbolic(x_minus_one, vec![x, y], &pool).unwrap();
        let product = lhs * rhs;
        assert_eq!(coeff(&product, &[]), Some(&int(-1)));
        assert_eq!(coeff(&product, &[2]), Some(&int(1)));
        // the linear terms cancel
        assert!(coeff(&product, &[1]).is_none());
    }

    #[test]
    fn integer_content() {
        // 6x + 4 → content = 2
        let (pool, x, y) = pool_xy();
        let six_x = pool.mul(vec![pool.integer(6_i32), x]);
        let expr = pool.add(vec![six_x, pool.integer(4_i32)]);
        let poly = MultiPoly::from_symbolic(expr, vec![x, y], &pool).unwrap();
        assert_eq!(poly.integer_content(), int(2));
    }

    #[test]
    fn primitive_part() {
        // 6x + 4 → primitive part = 3x + 2
        let (pool, x, y) = pool_xy();
        let six_x = pool.mul(vec![pool.integer(6_i32), x]);
        let expr = pool.add(vec![six_x, pool.integer(4_i32)]);
        let poly = MultiPoly::from_symbolic(expr, vec![x, y], &pool).unwrap();
        let primitive = poly.primitive_part();
        assert_eq!(coeff(&primitive, &[]), Some(&int(2)));
        assert_eq!(coeff(&primitive, &[1]), Some(&int(3)));
    }

    #[test]
    fn free_symbol_error() {
        // `z` is not in the variable list, so conversion must reject it.
        let pool = ExprPool::new();
        let x = pool.symbol("x", Domain::Real);
        let z = pool.symbol("z", Domain::Real);
        let expr = pool.add(vec![x, z]);
        let outcome = MultiPoly::from_symbolic(expr, vec![x], &pool);
        assert!(matches!(
            outcome,
            Err(ConversionError::UnexpectedSymbol(_))
        ));
    }
}