1use crate::kernel::{
2 domain::Domain,
3 expr::{BigFloat, BigInt, BigRat, ExprData, ExprId},
4};
5use std::fmt;
6
/// Name of the symbol that represents positive infinity (U+221E INFINITY).
pub const POS_INFINITY_SYMBOL: &str = "∞";
9
10#[cfg(feature = "parallel")]
30use dashmap::DashMap;
31
32#[cfg(not(feature = "parallel"))]
33use std::collections::HashMap;
34
35#[cfg(not(feature = "parallel"))]
36use std::sync::Mutex;
37
38#[cfg(feature = "parallel")]
43struct PoolIndex(DashMap<ExprData, ExprId>);
44
45#[cfg(not(feature = "parallel"))]
46struct PoolIndex(HashMap<ExprData, ExprId>);
47
#[cfg(feature = "parallel")]
impl PoolIndex {
    /// Creates an empty concurrent index.
    fn new() -> Self {
        Self(DashMap::new())
    }

    /// Returns the id recorded for `data`, or `None` if it has not been interned.
    fn get(&self, data: &ExprData) -> Option<ExprId> {
        let guard = self.0.get(data)?;
        Some(*guard.value())
    }

    /// Returns the id stored under `key`; if the key is absent, calls `f`,
    /// records its result under `key`, and returns it.
    fn or_insert_with(&self, key: ExprData, f: impl FnOnce() -> ExprId) -> ExprId {
        let slot = self.0.entry(key).or_insert_with(f);
        *slot.value()
    }
}
63
64#[cfg(not(feature = "parallel"))]
65impl PoolIndex {
66 fn new() -> Self {
67 PoolIndex(HashMap::new())
68 }
69 fn get(&self, data: &ExprData) -> Option<ExprId> {
70 self.0.get(data).copied()
71 }
72 fn insert(&mut self, data: ExprData, id: ExprId) {
73 self.0.insert(data, id);
74 }
75}
76
/// Hash-consing arena for expression nodes.
///
/// Each distinct `ExprData` is stored once in `nodes`, and its position there is its
/// `ExprId`. `index` maps node contents back to ids so that `intern` can return the
/// existing id for a structurally equal node instead of storing a duplicate.
pub struct ExprPool {
    /// Node storage; an `ExprId` is an index into this vector. This file only ever
    /// appends to it.
    nodes: boxcar::Vec<ExprData>,
    /// Contents -> id lookup table, shared directly (concurrent map) under `parallel`.
    #[cfg(feature = "parallel")]
    index: PoolIndex,
    /// Contents -> id lookup table, guarded by a mutex when `parallel` is off.
    #[cfg(not(feature = "parallel"))]
    index: Mutex<PoolIndex>,
}
94
// SAFETY: NOTE(review) — no justification is recorded for these impls.
// `ExprPool` is made of a `boxcar::Vec<ExprData>` plus a `DashMap`- or
// `Mutex`-backed index, which are normally `Send + Sync` whenever `ExprData`
// is. If that holds, these impls are redundant, and they would silently mask a
// future non-thread-safe field added to `ExprData`. Confirm that `ExprData` is
// `Send + Sync`, then consider removing these and relying on the auto traits.
unsafe impl Send for ExprPool {}
unsafe impl Sync for ExprPool {}
97
98impl ExprPool {
99 pub fn new() -> Self {
100 ExprPool {
101 nodes: boxcar::Vec::new(),
102 #[cfg(feature = "parallel")]
103 index: PoolIndex::new(),
104 #[cfg(not(feature = "parallel"))]
105 index: Mutex::new(PoolIndex::new()),
106 }
107 }
108
109 pub fn intern(&self, data: ExprData) -> ExprId {
112 #[cfg(feature = "parallel")]
113 {
114 if let Some(id) = self.index.get(&data) {
116 return id;
117 }
118 self.index
122 .or_insert_with(data.clone(), || ExprId(self.nodes.push(data) as u32))
123 }
124
125 #[cfg(not(feature = "parallel"))]
126 {
127 let mut idx = self.index.lock().expect("ExprPool index Mutex poisoned");
128 if let Some(id) = idx.get(&data) {
129 return id;
130 }
131 let id = ExprId(self.nodes.push(data.clone()) as u32);
132 idx.insert(data, id);
133 id
134 }
135 }
136
137 pub fn with<R, F: FnOnce(&ExprData) -> R>(&self, id: ExprId, f: F) -> R {
139 f(self
140 .nodes
141 .get(id.0 as usize)
142 .expect("ExprPool: ExprId out of range"))
143 }
144
145 pub fn get(&self, id: ExprId) -> ExprData {
147 self.with(id, |d| d.clone())
148 }
149
150 pub fn len(&self) -> usize {
152 self.nodes.count()
153 }
154
155 pub fn is_empty(&self) -> bool {
156 self.nodes.is_empty()
157 }
158
159 pub fn symbol(&self, name: impl Into<String>, domain: Domain) -> ExprId {
165 self.symbol_commutative(name, domain, true)
166 }
167
168 pub const IMAGINARY_UNIT_NAME: &'static str = "I";
175
176 pub fn imaginary_unit(&self) -> ExprId {
190 self.symbol(Self::IMAGINARY_UNIT_NAME, Domain::Complex)
191 }
192
193 pub fn is_imaginary_unit(&self, id: ExprId) -> bool {
197 self.with(id, |d| {
198 matches!(
199 d,
200 ExprData::Symbol { name, domain, .. }
201 if name == Self::IMAGINARY_UNIT_NAME && *domain == Domain::Complex
202 )
203 })
204 }
205
206 pub fn symbol_commutative(
209 &self,
210 name: impl Into<String>,
211 domain: Domain,
212 commutative: bool,
213 ) -> ExprId {
214 self.intern(ExprData::Symbol {
215 name: name.into(),
216 domain,
217 commutative,
218 })
219 }
220
221 pub fn integer(&self, n: impl Into<rug::Integer>) -> ExprId {
222 self.intern(ExprData::Integer(BigInt(n.into())))
223 }
224
225 pub fn rational(
226 &self,
227 numer: impl Into<rug::Integer>,
228 denom: impl Into<rug::Integer>,
229 ) -> ExprId {
230 let r = rug::Rational::from((numer.into(), denom.into()));
231 self.intern(ExprData::Rational(BigRat(r)))
232 }
233
234 pub fn float(&self, value: f64, prec: u32) -> ExprId {
235 let f = rug::Float::with_val(prec, value);
236 self.intern(ExprData::Float(BigFloat { inner: f, prec }))
237 }
238
239 pub fn add(&self, mut args: Vec<ExprId>) -> ExprId {
244 args.sort_unstable();
249 self.intern(ExprData::Add(args))
250 }
251
252 pub fn mul(&self, mut args: Vec<ExprId>) -> ExprId {
253 let sort_ok = args
255 .iter()
256 .all(|&a| crate::kernel::expr_props::mult_tree_is_commutative(self, a));
257 if sort_ok {
258 args.sort_unstable();
259 }
260 self.intern(ExprData::Mul(args))
261 }
262
263 pub fn pow(&self, base: ExprId, exp: ExprId) -> ExprId {
264 self.intern(ExprData::Pow { base, exp })
265 }
266
267 pub fn func(&self, name: impl Into<String>, args: Vec<ExprId>) -> ExprId {
268 self.intern(ExprData::Func {
269 name: name.into(),
270 args,
271 })
272 }
273
274 pub fn piecewise(&self, branches: Vec<(ExprId, ExprId)>, default: ExprId) -> ExprId {
284 self.intern(ExprData::Piecewise { branches, default })
285 }
286
287 pub fn predicate(&self, kind: crate::kernel::expr::PredicateKind, args: Vec<ExprId>) -> ExprId {
289 self.intern(ExprData::Predicate { kind, args })
290 }
291
292 pub fn pred_lt(&self, a: ExprId, b: ExprId) -> ExprId {
294 self.predicate(crate::kernel::expr::PredicateKind::Lt, vec![a, b])
295 }
296 pub fn pred_le(&self, a: ExprId, b: ExprId) -> ExprId {
297 self.predicate(crate::kernel::expr::PredicateKind::Le, vec![a, b])
298 }
299 pub fn pred_gt(&self, a: ExprId, b: ExprId) -> ExprId {
300 self.predicate(crate::kernel::expr::PredicateKind::Gt, vec![a, b])
301 }
302 pub fn pred_ge(&self, a: ExprId, b: ExprId) -> ExprId {
303 self.predicate(crate::kernel::expr::PredicateKind::Ge, vec![a, b])
304 }
305 pub fn pred_eq(&self, a: ExprId, b: ExprId) -> ExprId {
306 self.predicate(crate::kernel::expr::PredicateKind::Eq, vec![a, b])
307 }
308 pub fn pred_ne(&self, a: ExprId, b: ExprId) -> ExprId {
309 self.predicate(crate::kernel::expr::PredicateKind::Ne, vec![a, b])
310 }
311 pub fn pred_and(&self, args: Vec<ExprId>) -> ExprId {
312 self.predicate(crate::kernel::expr::PredicateKind::And, args)
313 }
314 pub fn pred_or(&self, args: Vec<ExprId>) -> ExprId {
315 self.predicate(crate::kernel::expr::PredicateKind::Or, args)
316 }
317 pub fn pred_not(&self, a: ExprId) -> ExprId {
318 self.predicate(crate::kernel::expr::PredicateKind::Not, vec![a])
319 }
320 pub fn pred_true(&self) -> ExprId {
321 self.predicate(crate::kernel::expr::PredicateKind::True, vec![])
322 }
323 pub fn pred_false(&self) -> ExprId {
324 self.predicate(crate::kernel::expr::PredicateKind::False, vec![])
325 }
326
327 pub fn forall(&self, var: ExprId, body: ExprId) -> ExprId {
330 self.intern(ExprData::Forall { var, body })
331 }
332
333 pub fn exists(&self, var: ExprId, body: ExprId) -> ExprId {
335 self.intern(ExprData::Exists { var, body })
336 }
337
338 pub fn root_sum(&self, poly: ExprId, var: ExprId, body: ExprId) -> ExprId {
340 self.intern(ExprData::RootSum { poly, var, body })
341 }
342
343 pub fn big_o(&self, arg: ExprId) -> ExprId {
345 self.intern(ExprData::BigO(arg))
346 }
347
348 pub fn pos_infinity(&self) -> ExprId {
350 self.symbol(POS_INFINITY_SYMBOL, Domain::Positive)
351 }
352
353 pub fn display(&self, id: ExprId) -> ExprDisplay<'_> {
358 ExprDisplay { id, pool: self }
359 }
360}
361
362impl Default for ExprPool {
363 fn default() -> Self {
364 Self::new()
365 }
366}
367
368pub struct ExprDisplay<'a> {
374 pub id: ExprId,
375 pub pool: &'a ExprPool,
376}
377
378impl fmt::Display for ExprDisplay<'_> {
379 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
380 let data = self.pool.get(self.id);
381 fmt_data(&data, self.pool, f)
382 }
383}
384
385impl fmt::Debug for ExprDisplay<'_> {
386 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
387 write!(f, "{}", self)
388 }
389}
390
391fn fmt_data(data: &ExprData, pool: &ExprPool, f: &mut fmt::Formatter<'_>) -> fmt::Result {
392 match data {
393 ExprData::Symbol { name, .. } => write!(f, "{}", name),
394 ExprData::Integer(n) => write!(f, "{}", n),
395 ExprData::Rational(r) => write!(f, "{}", r),
396 ExprData::Float(fl) => write!(f, "{}", fl),
397 ExprData::Add(args) => {
398 write!(f, "(")?;
399 for (i, &arg) in args.iter().enumerate() {
400 if i > 0 {
401 write!(f, " + ")?;
402 }
403 write!(f, "{}", pool.display(arg))?;
404 }
405 write!(f, ")")
406 }
407 ExprData::Mul(args) => {
408 write!(f, "(")?;
409 for (i, &arg) in args.iter().enumerate() {
410 if i > 0 {
411 write!(f, " * ")?;
412 }
413 write!(f, "{}", pool.display(arg))?;
414 }
415 write!(f, ")")
416 }
417 ExprData::Pow { base, exp } => {
418 write!(f, "{}^{}", pool.display(*base), pool.display(*exp))
419 }
420 ExprData::Func { name, args } => {
421 write!(f, "{}(", name)?;
422 for (i, &arg) in args.iter().enumerate() {
423 if i > 0 {
424 write!(f, ", ")?;
425 }
426 write!(f, "{}", pool.display(arg))?;
427 }
428 write!(f, ")")
429 }
430 ExprData::Piecewise { branches, default } => {
431 write!(f, "Piecewise(")?;
432 for (i, (cond, val)) in branches.iter().enumerate() {
433 if i > 0 {
434 write!(f, ", ")?;
435 }
436 write!(f, "({}, {})", pool.display(*cond), pool.display(*val))?;
437 }
438 write!(f, "; default={})", pool.display(*default))
439 }
440 ExprData::Predicate { kind, args } => match kind {
441 crate::kernel::expr::PredicateKind::True => write!(f, "True"),
442 crate::kernel::expr::PredicateKind::False => write!(f, "False"),
443 crate::kernel::expr::PredicateKind::Not => {
444 write!(f, "¬({})", pool.display(args[0]))
445 }
446 crate::kernel::expr::PredicateKind::And | crate::kernel::expr::PredicateKind::Or => {
447 write!(f, "(")?;
448 for (i, &arg) in args.iter().enumerate() {
449 if i > 0 {
450 write!(f, " {} ", kind)?;
451 }
452 write!(f, "{}", pool.display(arg))?;
453 }
454 write!(f, ")")
455 }
456 _ => {
457 write!(
458 f,
459 "({} {} {})",
460 pool.display(args[0]),
461 kind,
462 pool.display(args[1])
463 )
464 }
465 },
466 ExprData::Forall { var, body } => {
467 write!(f, "∀ {} . {}", pool.display(*var), pool.display(*body))
468 }
469 ExprData::Exists { var, body } => {
470 write!(f, "∃ {} . {}", pool.display(*var), pool.display(*body))
471 }
472 ExprData::BigO(arg) => {
473 write!(f, "O({})", pool.display(*arg))
474 }
475 ExprData::RootSum { poly, var, body } => {
476 write!(
477 f,
478 "RootSum({}, {} . {})",
479 pool.display(*poly),
480 pool.display(*var),
481 pool.display(*body)
482 )
483 }
484 }
485}
486
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::domain::Domain;

    fn pool() -> ExprPool {
        ExprPool::new()
    }

    #[test]
    fn noncommutative_mul_orders_distinct() {
        let p = pool();
        let a = p.symbol_commutative("A", Domain::Real, false);
        let b = p.symbol_commutative("B", Domain::Real, false);
        assert_ne!(
            p.mul(vec![a, b]),
            p.mul(vec![b, a]),
            "A*B and B*A must not hash-cons together for NC symbols"
        );
    }

    #[test]
    fn symbol_commutative_is_structural() {
        let p = pool();
        let xc = p.symbol_commutative("x", Domain::Real, true);
        let xnc = p.symbol_commutative("x", Domain::Real, false);
        assert_ne!(xc, xnc);
    }

    #[test]
    fn symbol_interning() {
        let p = pool();
        let x1 = p.symbol("x", Domain::Real);
        let x2 = p.symbol("x", Domain::Real);
        assert_eq!(x1, x2, "same symbol must return same ExprId");
    }

    #[test]
    fn domain_is_structural() {
        let p = pool();
        let xr = p.symbol("x", Domain::Real);
        let xc = p.symbol("x", Domain::Complex);
        assert_ne!(xr, xc, "same name but different domain must be distinct");
    }

    #[test]
    fn integer_interning() {
        let p = pool();
        let a = p.integer(42_i32);
        let b = p.integer(42_i32);
        let c = p.integer(99_i32);
        assert_eq!(a, b);
        assert_ne!(a, c);
    }

    #[test]
    fn rational_canonical() {
        let p = pool();
        let r1 = p.rational(2_i32, 4_i32);
        let r2 = p.rational(1_i32, 2_i32);
        assert_eq!(r1, r2, "rationals must be reduced to canonical form");
    }

    #[test]
    fn float_precision_is_structural() {
        let p = pool();
        let f53 = p.float(1.0, 53);
        let f64_ = p.float(1.0, 64);
        assert_ne!(
            f53, f64_,
            "same value but different precision is a different expr"
        );
    }

    #[test]
    fn subexpression_sharing() {
        let p = pool();
        let x = p.symbol("x", Domain::Real);
        let two = p.integer(2_i32);

        let xsq1 = p.pow(x, two);
        let xsq2 = p.pow(x, two);
        assert_eq!(xsq1, xsq2);

        assert_eq!(p.len(), 3);
    }

    #[test]
    fn len_tracks_distinct_nodes() {
        let p = pool();
        assert!(p.is_empty());
        assert_eq!(p.len(), 0);
        p.symbol("x", Domain::Real);
        p.symbol("x", Domain::Real);
        assert!(!p.is_empty());
        assert_eq!(p.len(), 1, "re-interning must not grow the pool");
    }

    #[test]
    fn add_interning() {
        let p = pool();
        let x = p.symbol("x", Domain::Real);
        let y = p.symbol("y", Domain::Real);
        let s1 = p.add(vec![x, y]);
        let s2 = p.add(vec![x, y]);
        assert_eq!(s1, s2);
    }

    #[test]
    fn arg_order_is_canonical() {
        let p = pool();
        let x = p.symbol("x", Domain::Real);
        let y = p.symbol("y", Domain::Real);
        let s1 = p.add(vec![x, y]);
        let s2 = p.add(vec![y, x]);
        assert_eq!(s1, s2, "a+b and b+a must be the same expression after PA-3");
        let m1 = p.mul(vec![x, y]);
        let m2 = p.mul(vec![y, x]);
        assert_eq!(m1, m2, "a*b and b*a must be the same expression after PA-3");
    }

    #[test]
    fn func_interning() {
        let p = pool();
        let x = p.symbol("x", Domain::Real);
        let s1 = p.func("sin", vec![x]);
        let s2 = p.func("sin", vec![x]);
        let c1 = p.func("cos", vec![x]);
        assert_eq!(s1, s2);
        assert_ne!(s1, c1);
    }

    #[test]
    fn imaginary_unit_is_recognized() {
        let p = pool();
        let i = p.imaginary_unit();
        assert_eq!(i, p.imaginary_unit());
        assert!(p.is_imaginary_unit(i));
        let real_i = p.symbol(ExprPool::IMAGINARY_UNIT_NAME, Domain::Real);
        assert!(!p.is_imaginary_unit(real_i), "domain must be Complex");
        let z = p.symbol("z", Domain::Complex);
        assert!(!p.is_imaginary_unit(z), "name must be the imaginary-unit name");
        assert_eq!(p.display(i).to_string(), ExprPool::IMAGINARY_UNIT_NAME);
    }

    #[test]
    fn display_symbol() {
        let p = pool();
        let x = p.symbol("x", Domain::Real);
        assert_eq!(p.display(x).to_string(), "x");
    }

    #[test]
    fn display_pos_infinity() {
        let p = pool();
        let inf = p.pos_infinity();
        assert_eq!(p.display(inf).to_string(), POS_INFINITY_SYMBOL);
    }

    #[test]
    fn display_integer() {
        let p = pool();
        let n = p.integer(42_i32);
        assert_eq!(p.display(n).to_string(), "42");
    }

    #[test]
    fn display_pow() {
        let p = pool();
        let x = p.symbol("x", Domain::Real);
        let two = p.integer(2_i32);
        let xsq = p.pow(x, two);
        assert_eq!(p.display(xsq).to_string(), "x^2");
    }

    #[test]
    fn display_add() {
        let p = pool();
        let x = p.symbol("x", Domain::Real);
        let y = p.symbol("y", Domain::Real);
        let s = p.add(vec![x, y]);
        assert_eq!(p.display(s).to_string(), "(x + y)");
    }

    #[test]
    fn display_mul_is_sorted() {
        let p = pool();
        let x = p.symbol("x", Domain::Real);
        let y = p.symbol("y", Domain::Real);
        let m = p.mul(vec![y, x]);
        assert_eq!(p.display(m).to_string(), "(x * y)");
    }

    #[test]
    fn display_func() {
        let p = pool();
        let x = p.symbol("x", Domain::Real);
        let s = p.func("sin", vec![x]);
        assert_eq!(p.display(s).to_string(), "sin(x)");
    }

    #[test]
    fn display_multi_arg_func_keeps_order() {
        let p = pool();
        let x = p.symbol("x", Domain::Real);
        let y = p.symbol("y", Domain::Real);
        let a = p.func("atan2", vec![y, x]);
        assert_eq!(p.display(a).to_string(), "atan2(y, x)");
    }

    #[test]
    fn display_nested() {
        let p = pool();
        let x = p.symbol("x", Domain::Real);
        let two = p.integer(2_i32);
        let xsq = p.pow(x, two);
        let one = p.integer(1_i32);
        let expr = p.add(vec![xsq, one]);
        assert_eq!(p.display(expr).to_string(), "(x^2 + 1)");
    }

    #[test]
    fn display_boolean_predicates() {
        let p = pool();
        let t = p.pred_true();
        let fls = p.pred_false();
        assert_eq!(p.display(t).to_string(), "True");
        assert_eq!(p.display(fls).to_string(), "False");
        let not_t = p.pred_not(t);
        assert_eq!(p.display(not_t).to_string(), "¬(True)");
    }

    #[test]
    fn display_piecewise_and_binders() {
        let p = pool();
        let x = p.symbol("x", Domain::Real);
        let y = p.symbol("y", Domain::Real);
        let t = p.pred_true();
        let zero = p.integer(0_i32);
        let one = p.integer(1_i32);

        let pw = p.piecewise(vec![(t, one)], zero);
        assert_eq!(p.display(pw).to_string(), "Piecewise((True, 1); default=0)");

        let all = p.forall(x, t);
        assert_eq!(p.display(all).to_string(), "∀ x . True");
        let some = p.exists(x, t);
        assert_eq!(p.display(some).to_string(), "∃ x . True");

        let o = p.big_o(x);
        assert_eq!(p.display(o).to_string(), "O(x)");

        let rs = p.root_sum(y, x, x);
        assert_eq!(p.display(rs).to_string(), "RootSum(y, x . x)");
    }

    fn assert_send_sync<T: Send + Sync>() {}

    #[test]
    fn pool_is_send_sync() {
        assert_send_sync::<ExprPool>();
    }
}