1use super::error::ConversionError;
2use crate::flint::mpoly::{FlintMPoly, FlintMPolyCtx};
3use crate::kernel::{ExprData, ExprId, ExprPool};
4use std::collections::BTreeMap;
5use std::fmt;
6use std::ops::{Add, Mul, Neg, Sub};
7
/// Exponent vector of a monomial: entry `i` is the power of the `i`-th
/// variable of the owning polynomial. `termmap_mul` and the symbol conversion
/// produce vectors with no trailing zeros, so the constant monomial is `[]`.
type Exponents = Vec<u32>;
/// Sparse polynomial body: exponent vector -> integer coefficient, kept in
/// lexicographic key order by the `BTreeMap`. The helpers in this module drop
/// coefficients that become zero.
type TermMap = BTreeMap<Exponents, rug::Integer>;
16
17fn termmap_add(mut a: TermMap, b: TermMap) -> TermMap {
18 for (exp, coeff) in b {
19 let entry = a
20 .entry(exp.clone())
21 .or_insert_with(|| rug::Integer::from(0));
22 *entry += coeff;
23 if *entry == 0 {
24 a.remove(&exp);
25 }
26 }
27 a
28}
29
30fn termmap_mul(a: &TermMap, b: &TermMap) -> TermMap {
31 let mut result = TermMap::new();
32 for (ea, ca) in a {
33 for (eb, cb) in b {
34 let prod = ca.clone() * cb.clone();
35 if prod == 0 {
36 continue;
37 }
38 let len = ea.len().max(eb.len());
39 let mut exp = vec![0u32; len];
40 for (i, &e) in ea.iter().enumerate() {
41 exp[i] += e;
42 }
43 for (i, &e) in eb.iter().enumerate() {
44 exp[i] += e;
45 }
46 while exp.last() == Some(&0) {
48 exp.pop();
49 }
50 let entry = result
51 .entry(exp.clone())
52 .or_insert_with(|| rug::Integer::from(0));
53 *entry += prod;
54 if *entry == 0 {
55 result.remove(&exp);
56 }
57 }
58 }
59 result
60}
61
62fn termmap_neg(map: TermMap) -> TermMap {
63 map.into_iter().map(|(k, v)| (k, -v)).collect()
64}
65
66fn termmap_pow(base: &TermMap, n: u32) -> TermMap {
67 if n == 0 {
68 let mut one = TermMap::new();
69 one.insert(vec![], rug::Integer::from(1));
70 return one;
71 }
72 if n == 1 {
73 return base.clone();
74 }
75 let half = termmap_pow(base, n / 2);
76 let mut result = termmap_mul(&half, &half);
77 if n % 2 == 1 {
78 result = termmap_mul(&result, base);
79 }
80 result
81}
82
83fn expr_to_multivariate_coeffs(
84 expr: ExprId,
85 vars: &[ExprId],
86 pool: &ExprPool,
87) -> Result<TermMap, ConversionError> {
88 enum NodeInfo {
90 Symbol { idx: Option<usize>, name: String },
91 Integer(rug::Integer),
92 NonIntCoeff,
93 Add(Vec<ExprId>),
94 Mul(Vec<ExprId>),
95 Pow { base: ExprId, exp: ExprId },
96 Func(String),
97 }
98
99 let info = pool.with(expr, |data| match data {
100 ExprData::Symbol { name, .. } => NodeInfo::Symbol {
101 idx: vars.iter().position(|&v| v == expr),
102 name: name.clone(),
103 },
104 ExprData::Integer(n) => NodeInfo::Integer(n.0.clone()),
105 ExprData::Rational(_) | ExprData::Float(_) => NodeInfo::NonIntCoeff,
106 ExprData::Add(args) => NodeInfo::Add(args.clone()),
107 ExprData::Mul(args) => NodeInfo::Mul(args.clone()),
108 ExprData::Pow { base, exp } => NodeInfo::Pow {
109 base: *base,
110 exp: *exp,
111 },
112 ExprData::Func { name, .. } => NodeInfo::Func(name.clone()),
113 ExprData::Piecewise { .. }
114 | ExprData::Predicate { .. }
115 | ExprData::Forall { .. }
116 | ExprData::Exists { .. }
117 | ExprData::BigO(_) => NodeInfo::Func("piecewise_or_predicate".to_string()),
118 });
119
120 match info {
121 NodeInfo::Symbol { idx: Some(idx), .. } => {
122 let mut exp = vec![0u32; idx + 1];
123 exp[idx] = 1;
124 let mut map = TermMap::new();
125 map.insert(exp, rug::Integer::from(1));
126 Ok(map)
127 }
128 NodeInfo::Symbol { name, .. } => Err(ConversionError::UnexpectedSymbol(name)),
129 NodeInfo::Integer(n) => {
130 let mut map = TermMap::new();
131 if n != 0 {
132 map.insert(vec![], n);
133 }
134 Ok(map)
135 }
136 NodeInfo::NonIntCoeff => Err(ConversionError::NonIntegerCoefficient),
137 NodeInfo::Add(args) => {
138 let mut acc = TermMap::new();
139 for arg in args {
140 let sub = expr_to_multivariate_coeffs(arg, vars, pool)?;
141 acc = termmap_add(acc, sub);
142 }
143 Ok(acc)
144 }
145 NodeInfo::Mul(args) => {
146 let mut acc: TermMap = {
147 let mut m = TermMap::new();
148 m.insert(vec![], rug::Integer::from(1));
149 m
150 };
151 for arg in args {
152 let sub = expr_to_multivariate_coeffs(arg, vars, pool)?;
153 acc = termmap_mul(&acc, &sub);
154 }
155 Ok(acc)
156 }
157 NodeInfo::Pow { base, exp } => {
158 let n = pool
160 .with(exp, |data| match data {
161 ExprData::Integer(n) => Some(n.0.clone()),
162 _ => None,
163 })
164 .ok_or(ConversionError::NonConstantExponent)?;
165 if n < 0 {
166 return Err(ConversionError::NegativeExponent);
167 }
168 let n_u32 = n.to_u32().ok_or(ConversionError::ExponentTooLarge)?;
169 let base_coeffs = expr_to_multivariate_coeffs(base, vars, pool)?;
170 Ok(termmap_pow(&base_coeffs, n_u32))
171 }
172 NodeInfo::Func(name) => Err(ConversionError::NonPolynomialFunction(name)),
173 }
174}
175
/// A sparse multivariate polynomial with `rug::Integer` coefficients.
///
/// `vars` fixes the variable order: position `i` of every exponent vector in
/// `terms` refers to `vars[i]`. Equality is the derived structural one, so two
/// values compare equal only when both the variable lists and the term maps
/// (including the exact shape of each exponent-vector key) match.
#[derive(Clone, PartialEq, Eq)]
pub struct MultiPoly {
    /// Variables, as expression ids, in exponent-vector order.
    pub vars: Vec<ExprId>,
    /// Exponent vector -> coefficient.
    pub terms: TermMap,
}
190
191impl MultiPoly {
192 pub fn zero(vars: Vec<ExprId>) -> Self {
193 MultiPoly {
194 vars,
195 terms: TermMap::new(),
196 }
197 }
198
199 pub fn constant(vars: Vec<ExprId>, c: i64) -> Self {
200 let mut terms = TermMap::new();
201 if c != 0 {
202 terms.insert(vec![], rug::Integer::from(c));
203 }
204 MultiPoly { vars, terms }
205 }
206
207 pub fn from_symbolic(
208 expr: ExprId,
209 vars: Vec<ExprId>,
210 pool: &ExprPool,
211 ) -> Result<Self, ConversionError> {
212 let terms = expr_to_multivariate_coeffs(expr, &vars, pool)?;
213 Ok(MultiPoly { vars, terms })
214 }
215
216 pub fn is_zero(&self) -> bool {
217 self.terms.is_empty()
218 }
219
220 pub fn total_degree(&self) -> u32 {
221 self.terms
222 .keys()
223 .map(|exp| exp.iter().sum::<u32>())
224 .max()
225 .unwrap_or(0)
226 }
227
228 pub fn integer_content(&self) -> rug::Integer {
230 self.terms.values().fold(rug::Integer::from(0), |acc, c| {
231 rug::Integer::from(acc.gcd_ref(c))
232 })
233 }
234
235 pub fn primitive_part(&self) -> Self {
237 let g = self.integer_content();
238 if g == 0 {
239 return self.clone();
240 }
241 self.div_integer(&g)
242 }
243
244 pub fn compatible_with(&self, other: &Self) -> bool {
246 self.vars == other.vars
247 }
248
249 pub fn gcd(&self, other: &Self) -> Option<Self> {
256 if !self.compatible_with(other) {
257 return None;
258 }
259 if self.is_zero() || other.is_zero() {
260 return None;
261 }
262
263 let nvars = self.vars.len();
264
265 let ctx = FlintMPolyCtx::new(nvars.max(1));
267
268 let a = multi_to_flint(self, &ctx);
269 let b = multi_to_flint(other, &ctx);
270
271 let g = a.gcd(&b, &ctx)?;
272
273 let terms = g.terms(nvars.max(1), &ctx);
275 let mut gcd = MultiPoly {
276 vars: self.vars.clone(),
277 terms,
278 };
279
280 if let Some((_, lc)) = gcd.terms.iter().next_back() {
282 if *lc < 0 {
283 gcd = -gcd;
284 }
285 }
286
287 Some(gcd)
288 }
289
290 pub fn to_expr(&self, pool: &ExprPool) -> ExprId {
295 if self.terms.is_empty() {
296 return pool.integer(0_i32);
297 }
298 let summands: Vec<ExprId> = self
299 .terms
300 .iter()
301 .map(|(exps, coeff)| {
302 let coeff_id = pool.integer(coeff.clone());
303 let mut factors = vec![coeff_id];
304 for (i, &e) in exps.iter().enumerate() {
305 if e == 0 || i >= self.vars.len() {
306 continue;
307 }
308 let var = self.vars[i];
309 let exp_id = pool.integer(e);
310 factors.push(if e == 1 { var } else { pool.pow(var, exp_id) });
311 }
312 match factors.len() {
313 0 => pool.integer(1_i32),
314 1 => factors[0],
315 _ => pool.mul(factors),
316 }
317 })
318 .collect();
319
320 match summands.len() {
321 0 => pool.integer(0_i32),
322 1 => summands[0],
323 _ => pool.add(summands),
324 }
325 }
326
327 pub fn div_integer(&self, d: &rug::Integer) -> Self {
329 debug_assert!(
330 self.terms.values().all(|v| v.is_divisible(d)),
331 "div_integer: not all coefficients are divisible by {d}"
332 );
333 let terms = self
334 .terms
335 .iter()
336 .map(|(k, v)| (k.clone(), rug::Integer::from(v.div_exact_ref(d))))
337 .collect();
338 MultiPoly {
339 vars: self.vars.clone(),
340 terms,
341 }
342 }
343}
344
/// Crate-visible entry point to the private `multi_to_flint` conversion.
/// It builds a `FlintMPoly` in `ctx` from the terms of `p`.
pub(crate) fn multi_to_flint_pub(p: &MultiPoly, ctx: &FlintMPolyCtx) -> FlintMPoly {
    multi_to_flint(p, ctx)
}
349
350fn multi_to_flint(p: &MultiPoly, ctx: &FlintMPolyCtx) -> FlintMPoly {
351 let nvars = p.vars.len().max(1);
352 let mut fp = FlintMPoly::new(ctx);
353 for (exp, coeff) in &p.terms {
354 let mut exp_u64 = vec![0u64; nvars];
355 for (i, &e) in exp.iter().enumerate() {
356 if i < nvars {
357 exp_u64[i] = e as u64;
358 }
359 }
360 fp.push_term(coeff, &exp_u64, ctx);
361 }
362 fp.finish(ctx);
363 fp
364}
365
366fn same_vars(a: &MultiPoly, b: &MultiPoly) {
367 assert_eq!(
368 a.vars, b.vars,
369 "MultiPoly arithmetic requires both operands to share the same variable list"
370 );
371}
372
373impl Neg for MultiPoly {
374 type Output = Self;
375 fn neg(self) -> Self {
376 MultiPoly {
377 vars: self.vars,
378 terms: termmap_neg(self.terms),
379 }
380 }
381}
382
383impl Add for MultiPoly {
384 type Output = Self;
385 fn add(self, rhs: Self) -> Self {
386 same_vars(&self, &rhs);
387 MultiPoly {
388 vars: self.vars.clone(),
389 terms: termmap_add(self.terms, rhs.terms),
390 }
391 }
392}
393
394impl Sub for MultiPoly {
395 type Output = Self;
396 fn sub(self, rhs: Self) -> Self {
397 same_vars(&self, &rhs);
398 MultiPoly {
399 vars: self.vars.clone(),
400 terms: termmap_add(self.terms, termmap_neg(rhs.terms)),
401 }
402 }
403}
404
405impl Mul for MultiPoly {
406 type Output = Self;
407 fn mul(self, rhs: Self) -> Self {
408 same_vars(&self, &rhs);
409 MultiPoly {
410 vars: self.vars.clone(),
411 terms: termmap_mul(&self.terms, &rhs.terms),
412 }
413 }
414}
415
416impl fmt::Display for MultiPoly {
417 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
418 if self.is_zero() {
419 return write!(f, "0");
420 }
421 let mut first = true;
422 for (exp, coeff) in &self.terms {
424 if !first {
425 if *coeff > 0 {
426 write!(f, " + ")?;
427 } else {
428 write!(f, " - ")?;
429 }
430 } else if *coeff < 0 {
431 write!(f, "-")?;
432 }
433 first = false;
434
435 let abs_coeff = rug::Integer::from(coeff.abs_ref());
436 let has_vars = exp.iter().any(|&e| e > 0);
437 if abs_coeff != 1 || !has_vars {
438 write!(f, "{abs_coeff}")?;
439 }
440 for (i, &e) in exp.iter().enumerate() {
441 if e == 0 {
442 continue;
443 }
444 let var_label = format!("x{i}");
446 if e == 1 {
447 write!(f, "{var_label}")?;
448 } else {
449 write!(f, "{var_label}^{e}")?;
450 }
451 }
452 }
453 Ok(())
454 }
455}
456
457impl fmt::Debug for MultiPoly {
458 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
459 write!(f, "MultiPoly(vars={:?}, {})", self.vars, self)
460 }
461}
462
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{Domain, ExprPool};

    /// Fresh pool with two real symbols `x` and `y`.
    fn pool_xy() -> (ExprPool, ExprId, ExprId) {
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let y = p.symbol("y", Domain::Real);
        (p, x, y)
    }

    #[test]
    fn univariate_from_symbolic() {
        let (p, x, y) = pool_xy();
        // x^2 + 2x + 1
        let xsq = p.pow(x, p.integer(2_i32));
        let two_x = p.mul(vec![p.integer(2_i32), x]);
        let expr = p.add(vec![xsq, two_x, p.integer(1_i32)]);
        let poly = MultiPoly::from_symbolic(expr, vec![x, y], &p).unwrap();
        assert_eq!(poly.terms[&vec![]], rug::Integer::from(1));
        assert_eq!(poly.terms[&vec![1]], rug::Integer::from(2));
        assert_eq!(poly.terms[&vec![2]], rug::Integer::from(1));
    }

    #[test]
    fn bivariate_from_symbolic() {
        let (p, x, y) = pool_xy();
        // x * y
        let expr = p.mul(vec![x, y]);
        let poly = MultiPoly::from_symbolic(expr, vec![x, y], &p).unwrap();
        assert_eq!(poly.terms[&vec![1, 1]], rug::Integer::from(1));
        assert_eq!(poly.terms.len(), 1);
    }

    #[test]
    fn zero_poly() {
        let (_p, x, y) = pool_xy();
        let zero = MultiPoly::zero(vec![x, y]);
        assert!(zero.is_zero());
    }

    #[test]
    fn add_polys() {
        let (p, x, y) = pool_xy();
        let a = MultiPoly::from_symbolic(x, vec![x, y], &p).unwrap();
        let b = MultiPoly::from_symbolic(y, vec![x, y], &p).unwrap();
        let sum = a + b;
        assert_eq!(sum.terms[&vec![1]], rug::Integer::from(1));
        assert_eq!(sum.terms[&vec![0, 1]], rug::Integer::from(1));
    }

    #[test]
    fn mul_polys() {
        let (p, x, y) = pool_xy();
        // (x + 1)(x - 1) = x^2 - 1
        let a = MultiPoly::from_symbolic(p.add(vec![x, p.integer(1_i32)]), vec![x, y], &p).unwrap();
        let b =
            MultiPoly::from_symbolic(p.add(vec![x, p.integer(-1_i32)]), vec![x, y], &p).unwrap();
        let prod = a * b;
        assert_eq!(prod.terms[&vec![]], rug::Integer::from(-1));
        assert_eq!(prod.terms[&vec![2]], rug::Integer::from(1));
        assert!(!prod.terms.contains_key(&vec![1]));
    }

    #[test]
    fn integer_content() {
        let (p, x, y) = pool_xy();
        // 6x + 4
        let expr = p.add(vec![p.mul(vec![p.integer(6_i32), x]), p.integer(4_i32)]);
        let poly = MultiPoly::from_symbolic(expr, vec![x, y], &p).unwrap();
        assert_eq!(poly.integer_content(), rug::Integer::from(2));
    }

    #[test]
    fn primitive_part() {
        let (p, x, y) = pool_xy();
        // 6x + 4 -> 3x + 2
        let expr = p.add(vec![p.mul(vec![p.integer(6_i32), x]), p.integer(4_i32)]);
        let poly = MultiPoly::from_symbolic(expr, vec![x, y], &p).unwrap();
        let pp = poly.primitive_part();
        assert_eq!(pp.terms[&vec![]], rug::Integer::from(2));
        assert_eq!(pp.terms[&vec![1]], rug::Integer::from(3));
    }

    #[test]
    fn free_symbol_error() {
        let p = ExprPool::new();
        let x = p.symbol("x", Domain::Real);
        let z = p.symbol("z", Domain::Real);
        let expr = p.add(vec![x, z]);
        assert!(matches!(
            MultiPoly::from_symbolic(expr, vec![x], &p),
            Err(ConversionError::UnexpectedSymbol(_))
        ));
    }

    #[test]
    fn pow_expands_binomial() {
        let (p, x, y) = pool_xy();
        // (x + 1)^3 = x^3 + 3x^2 + 3x + 1
        let base = p.add(vec![x, p.integer(1_i32)]);
        let expr = p.pow(base, p.integer(3_i32));
        let poly = MultiPoly::from_symbolic(expr, vec![x, y], &p).unwrap();
        assert_eq!(poly.terms.len(), 4);
        assert_eq!(poly.terms[&vec![]], rug::Integer::from(1));
        assert_eq!(poly.terms[&vec![1]], rug::Integer::from(3));
        assert_eq!(poly.terms[&vec![2]], rug::Integer::from(3));
        assert_eq!(poly.terms[&vec![3]], rug::Integer::from(1));
    }

    #[test]
    fn negative_exponent_is_rejected() {
        let (p, x, y) = pool_xy();
        let expr = p.pow(x, p.integer(-1_i32));
        assert!(matches!(
            MultiPoly::from_symbolic(expr, vec![x, y], &p),
            Err(ConversionError::NegativeExponent)
        ));
    }

    #[test]
    fn sub_self_is_zero() {
        let (p, x, y) = pool_xy();
        // x + 2y
        let expr = p.add(vec![x, p.mul(vec![p.integer(2_i32), y])]);
        let a = MultiPoly::from_symbolic(expr, vec![x, y], &p).unwrap();
        assert!((a.clone() - a).is_zero());
    }

    #[test]
    fn neg_flips_every_sign() {
        let (p, x, y) = pool_xy();
        let a = MultiPoly::from_symbolic(p.add(vec![x, p.integer(1_i32)]), vec![x, y], &p).unwrap();
        let negated = -a;
        assert_eq!(negated.terms[&vec![]], rug::Integer::from(-1));
        assert_eq!(negated.terms[&vec![1]], rug::Integer::from(-1));
    }

    #[test]
    fn total_degree_sums_exponents() {
        let (p, x, y) = pool_xy();
        // x^2 * y has total degree 3
        let expr = p.mul(vec![p.pow(x, p.integer(2_i32)), y]);
        let poly = MultiPoly::from_symbolic(expr, vec![x, y], &p).unwrap();
        assert_eq!(poly.total_degree(), 3);
        assert_eq!(MultiPoly::zero(vec![x, y]).total_degree(), 0);
    }

    #[test]
    fn constant_constructor() {
        let (_p, x, y) = pool_xy();
        assert!(MultiPoly::constant(vec![x, y], 0).is_zero());
        assert_eq!(
            MultiPoly::constant(vec![x, y], -7).terms[&vec![]],
            rug::Integer::from(-7)
        );
    }

    #[test]
    fn display_uses_positional_names() {
        let (_p, x, y) = pool_xy();
        let mut terms = TermMap::new();
        terms.insert(vec![], rug::Integer::from(-1));
        terms.insert(vec![2], rug::Integer::from(1));
        terms.insert(vec![1, 1], rug::Integer::from(-3));
        let poly = MultiPoly {
            vars: vec![x, y],
            terms,
        };
        // Key order: [] < [1, 1] < [2]
        assert_eq!(poly.to_string(), "-1 - 3x0x1 + x0^2");
        assert_eq!(MultiPoly::zero(vec![x, y]).to_string(), "0");
    }

    #[test]
    fn to_expr_round_trips() {
        let (p, x, y) = pool_xy();
        // x^2 - 2y + 3
        let expr = p.add(vec![
            p.pow(x, p.integer(2_i32)),
            p.mul(vec![p.integer(-2_i32), y]),
            p.integer(3_i32),
        ]);
        let poly = MultiPoly::from_symbolic(expr, vec![x, y], &p).unwrap();
        let back = MultiPoly::from_symbolic(poly.to_expr(&p), vec![x, y], &p).unwrap();
        assert_eq!(back, poly);
    }

    #[test]
    #[should_panic(expected = "same variable list")]
    fn add_rejects_mismatched_vars() {
        let (_p, x, y) = pool_xy();
        let _ = MultiPoly::zero(vec![x]) + MultiPoly::zero(vec![x, y]);
    }
}
566}