1use crate::kernel::{
2 domain::Domain,
3 expr::{BigFloat, BigInt, BigRat, ExprData, ExprId},
4};
5use std::fmt;
6
/// Name of the symbol returned by [`ExprPool::pos_infinity`]:
/// U+221E INFINITY ("∞").
pub const POS_INFINITY_SYMBOL: &str = "\u{221E}";
9
10#[cfg(feature = "parallel")]
30use dashmap::DashMap;
31
32#[cfg(not(feature = "parallel"))]
33use std::collections::HashMap;
34
35#[cfg(not(feature = "parallel"))]
36use std::sync::Mutex;
37
/// Reverse index from node contents to the id they were interned under.
/// Concurrent `DashMap` variant, compiled when the `parallel` feature is on.
#[cfg(feature = "parallel")]
struct PoolIndex(DashMap<ExprData, ExprId>);

/// Reverse index from node contents to the id they were interned under.
/// Plain `HashMap` variant; `ExprPool` wraps it in a `Mutex` for shared access.
#[cfg(not(feature = "parallel"))]
struct PoolIndex(HashMap<ExprData, ExprId>);
47
#[cfg(feature = "parallel")]
impl PoolIndex {
    /// Creates an empty concurrent index.
    fn new() -> Self {
        Self(DashMap::new())
    }

    /// Looks up the id previously recorded for `data`, if any.
    fn get(&self, data: &ExprData) -> Option<ExprId> {
        let PoolIndex(map) = self;
        map.get(data).map(|entry| *entry.value())
    }

    /// Returns the id recorded for `key`, calling `f` to produce one only if
    /// `key` is not yet present.
    ///
    /// Relies on DashMap's entry API keeping the key's shard locked while `f`
    /// runs, so racing callers with equal keys invoke `f` at most once.
    fn or_insert_with(&self, key: ExprData, f: impl FnOnce() -> ExprId) -> ExprId {
        let PoolIndex(map) = self;
        let slot = map.entry(key).or_insert_with(f);
        *slot.value()
    }
}
63
64#[cfg(not(feature = "parallel"))]
65impl PoolIndex {
66 fn new() -> Self {
67 PoolIndex(HashMap::new())
68 }
69 fn get(&self, data: &ExprData) -> Option<ExprId> {
70 self.0.get(data).copied()
71 }
72 fn insert(&mut self, data: ExprData, id: ExprId) {
73 self.0.insert(data, id);
74 }
75}
76
/// Hash-consing arena for expressions: each structurally distinct
/// [`ExprData`] is stored once and addressed by an [`ExprId`].
pub struct ExprPool {
    /// Node storage. `intern` appends to it, and an `ExprId` is the index of
    /// its node here.
    nodes: boxcar::Vec<ExprData>,
    /// Reverse map used by `intern` to deduplicate nodes (concurrent map when
    /// the `parallel` feature is enabled).
    #[cfg(feature = "parallel")]
    index: PoolIndex,
    /// Reverse map used by `intern` to deduplicate nodes; the mutex keeps the
    /// lookup and the push into `nodes` atomic with respect to other callers.
    #[cfg(not(feature = "parallel"))]
    index: Mutex<PoolIndex>,
}
94
95unsafe impl Send for ExprPool {}
96unsafe impl Sync for ExprPool {}
97
98impl ExprPool {
99 pub fn new() -> Self {
100 ExprPool {
101 nodes: boxcar::Vec::new(),
102 #[cfg(feature = "parallel")]
103 index: PoolIndex::new(),
104 #[cfg(not(feature = "parallel"))]
105 index: Mutex::new(PoolIndex::new()),
106 }
107 }
108
109 pub fn intern(&self, data: ExprData) -> ExprId {
112 #[cfg(feature = "parallel")]
113 {
114 if let Some(id) = self.index.get(&data) {
116 return id;
117 }
118 self.index
122 .or_insert_with(data.clone(), || ExprId(self.nodes.push(data) as u32))
123 }
124
125 #[cfg(not(feature = "parallel"))]
126 {
127 let mut idx = self.index.lock().expect("ExprPool index Mutex poisoned");
128 if let Some(id) = idx.get(&data) {
129 return id;
130 }
131 let id = ExprId(self.nodes.push(data.clone()) as u32);
132 idx.insert(data, id);
133 id
134 }
135 }
136
137 pub fn with<R, F: FnOnce(&ExprData) -> R>(&self, id: ExprId, f: F) -> R {
139 f(self
140 .nodes
141 .get(id.0 as usize)
142 .expect("ExprPool: ExprId out of range"))
143 }
144
145 pub fn get(&self, id: ExprId) -> ExprData {
147 self.with(id, |d| d.clone())
148 }
149
150 pub fn len(&self) -> usize {
152 self.nodes.count()
153 }
154
155 pub fn is_empty(&self) -> bool {
156 self.nodes.is_empty()
157 }
158
159 pub fn symbol(&self, name: impl Into<String>, domain: Domain) -> ExprId {
165 self.symbol_commutative(name, domain, true)
166 }
167
168 pub fn symbol_commutative(
171 &self,
172 name: impl Into<String>,
173 domain: Domain,
174 commutative: bool,
175 ) -> ExprId {
176 self.intern(ExprData::Symbol {
177 name: name.into(),
178 domain,
179 commutative,
180 })
181 }
182
183 pub fn integer(&self, n: impl Into<rug::Integer>) -> ExprId {
184 self.intern(ExprData::Integer(BigInt(n.into())))
185 }
186
187 pub fn rational(
188 &self,
189 numer: impl Into<rug::Integer>,
190 denom: impl Into<rug::Integer>,
191 ) -> ExprId {
192 let r = rug::Rational::from((numer.into(), denom.into()));
193 self.intern(ExprData::Rational(BigRat(r)))
194 }
195
196 pub fn float(&self, value: f64, prec: u32) -> ExprId {
197 let f = rug::Float::with_val(prec, value);
198 self.intern(ExprData::Float(BigFloat { inner: f, prec }))
199 }
200
201 pub fn add(&self, mut args: Vec<ExprId>) -> ExprId {
206 args.sort_unstable();
211 self.intern(ExprData::Add(args))
212 }
213
214 pub fn mul(&self, mut args: Vec<ExprId>) -> ExprId {
215 let sort_ok = args
217 .iter()
218 .all(|&a| crate::kernel::expr_props::mult_tree_is_commutative(self, a));
219 if sort_ok {
220 args.sort_unstable();
221 }
222 self.intern(ExprData::Mul(args))
223 }
224
225 pub fn pow(&self, base: ExprId, exp: ExprId) -> ExprId {
226 self.intern(ExprData::Pow { base, exp })
227 }
228
229 pub fn func(&self, name: impl Into<String>, args: Vec<ExprId>) -> ExprId {
230 self.intern(ExprData::Func {
231 name: name.into(),
232 args,
233 })
234 }
235
236 pub fn piecewise(&self, branches: Vec<(ExprId, ExprId)>, default: ExprId) -> ExprId {
246 self.intern(ExprData::Piecewise { branches, default })
247 }
248
249 pub fn predicate(&self, kind: crate::kernel::expr::PredicateKind, args: Vec<ExprId>) -> ExprId {
251 self.intern(ExprData::Predicate { kind, args })
252 }
253
254 pub fn pred_lt(&self, a: ExprId, b: ExprId) -> ExprId {
256 self.predicate(crate::kernel::expr::PredicateKind::Lt, vec![a, b])
257 }
258 pub fn pred_le(&self, a: ExprId, b: ExprId) -> ExprId {
259 self.predicate(crate::kernel::expr::PredicateKind::Le, vec![a, b])
260 }
261 pub fn pred_gt(&self, a: ExprId, b: ExprId) -> ExprId {
262 self.predicate(crate::kernel::expr::PredicateKind::Gt, vec![a, b])
263 }
264 pub fn pred_ge(&self, a: ExprId, b: ExprId) -> ExprId {
265 self.predicate(crate::kernel::expr::PredicateKind::Ge, vec![a, b])
266 }
267 pub fn pred_eq(&self, a: ExprId, b: ExprId) -> ExprId {
268 self.predicate(crate::kernel::expr::PredicateKind::Eq, vec![a, b])
269 }
270 pub fn pred_ne(&self, a: ExprId, b: ExprId) -> ExprId {
271 self.predicate(crate::kernel::expr::PredicateKind::Ne, vec![a, b])
272 }
273 pub fn pred_and(&self, args: Vec<ExprId>) -> ExprId {
274 self.predicate(crate::kernel::expr::PredicateKind::And, args)
275 }
276 pub fn pred_or(&self, args: Vec<ExprId>) -> ExprId {
277 self.predicate(crate::kernel::expr::PredicateKind::Or, args)
278 }
279 pub fn pred_not(&self, a: ExprId) -> ExprId {
280 self.predicate(crate::kernel::expr::PredicateKind::Not, vec![a])
281 }
282 pub fn pred_true(&self) -> ExprId {
283 self.predicate(crate::kernel::expr::PredicateKind::True, vec![])
284 }
285 pub fn pred_false(&self) -> ExprId {
286 self.predicate(crate::kernel::expr::PredicateKind::False, vec![])
287 }
288
289 pub fn forall(&self, var: ExprId, body: ExprId) -> ExprId {
292 self.intern(ExprData::Forall { var, body })
293 }
294
295 pub fn exists(&self, var: ExprId, body: ExprId) -> ExprId {
297 self.intern(ExprData::Exists { var, body })
298 }
299
300 pub fn big_o(&self, arg: ExprId) -> ExprId {
302 self.intern(ExprData::BigO(arg))
303 }
304
305 pub fn pos_infinity(&self) -> ExprId {
307 self.symbol(POS_INFINITY_SYMBOL, Domain::Positive)
308 }
309
310 pub fn display(&self, id: ExprId) -> ExprDisplay<'_> {
315 ExprDisplay { id, pool: self }
316 }
317}
318
319impl Default for ExprPool {
320 fn default() -> Self {
321 Self::new()
322 }
323}
324
/// Formatting adapter for one expression, returned by [`ExprPool::display`].
///
/// Implements `Display` and `Debug` by looking nodes up in `pool` on demand.
pub struct ExprDisplay<'a> {
    /// The expression to render.
    pub id: ExprId,
    /// The pool `id` (and all of its children) belong to.
    pub pool: &'a ExprPool,
}
334
335impl fmt::Display for ExprDisplay<'_> {
336 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
337 let data = self.pool.get(self.id);
338 fmt_data(&data, self.pool, f)
339 }
340}
341
342impl fmt::Debug for ExprDisplay<'_> {
343 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
344 write!(f, "{}", self)
345 }
346}
347
348fn fmt_data(data: &ExprData, pool: &ExprPool, f: &mut fmt::Formatter<'_>) -> fmt::Result {
349 match data {
350 ExprData::Symbol { name, .. } => write!(f, "{}", name),
351 ExprData::Integer(n) => write!(f, "{}", n),
352 ExprData::Rational(r) => write!(f, "{}", r),
353 ExprData::Float(fl) => write!(f, "{}", fl),
354 ExprData::Add(args) => {
355 write!(f, "(")?;
356 for (i, &arg) in args.iter().enumerate() {
357 if i > 0 {
358 write!(f, " + ")?;
359 }
360 write!(f, "{}", pool.display(arg))?;
361 }
362 write!(f, ")")
363 }
364 ExprData::Mul(args) => {
365 write!(f, "(")?;
366 for (i, &arg) in args.iter().enumerate() {
367 if i > 0 {
368 write!(f, " * ")?;
369 }
370 write!(f, "{}", pool.display(arg))?;
371 }
372 write!(f, ")")
373 }
374 ExprData::Pow { base, exp } => {
375 write!(f, "{}^{}", pool.display(*base), pool.display(*exp))
376 }
377 ExprData::Func { name, args } => {
378 write!(f, "{}(", name)?;
379 for (i, &arg) in args.iter().enumerate() {
380 if i > 0 {
381 write!(f, ", ")?;
382 }
383 write!(f, "{}", pool.display(arg))?;
384 }
385 write!(f, ")")
386 }
387 ExprData::Piecewise { branches, default } => {
388 write!(f, "Piecewise(")?;
389 for (i, (cond, val)) in branches.iter().enumerate() {
390 if i > 0 {
391 write!(f, ", ")?;
392 }
393 write!(f, "({}, {})", pool.display(*cond), pool.display(*val))?;
394 }
395 write!(f, "; default={})", pool.display(*default))
396 }
397 ExprData::Predicate { kind, args } => match kind {
398 crate::kernel::expr::PredicateKind::True => write!(f, "True"),
399 crate::kernel::expr::PredicateKind::False => write!(f, "False"),
400 crate::kernel::expr::PredicateKind::Not => {
401 write!(f, "¬({})", pool.display(args[0]))
402 }
403 crate::kernel::expr::PredicateKind::And | crate::kernel::expr::PredicateKind::Or => {
404 write!(f, "(")?;
405 for (i, &arg) in args.iter().enumerate() {
406 if i > 0 {
407 write!(f, " {} ", kind)?;
408 }
409 write!(f, "{}", pool.display(arg))?;
410 }
411 write!(f, ")")
412 }
413 _ => {
414 write!(
415 f,
416 "({} {} {})",
417 pool.display(args[0]),
418 kind,
419 pool.display(args[1])
420 )
421 }
422 },
423 ExprData::Forall { var, body } => {
424 write!(f, "∀ {} . {}", pool.display(*var), pool.display(*body))
425 }
426 ExprData::Exists { var, body } => {
427 write!(f, "∃ {} . {}", pool.display(*var), pool.display(*body))
428 }
429 ExprData::BigO(arg) => {
430 write!(f, "O({})", pool.display(*arg))
431 }
432 }
433}
434
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::domain::Domain;

    fn pool() -> ExprPool {
        ExprPool::new()
    }

    #[test]
    fn noncommutative_mul_orders_distinct() {
        let p = pool();
        let a = p.symbol_commutative("A", Domain::Real, false);
        let b = p.symbol_commutative("B", Domain::Real, false);
        assert_ne!(
            p.mul(vec![a, b]),
            p.mul(vec![b, a]),
            "A*B and B*A must not hash-cons together for NC symbols"
        );
    }

    #[test]
    fn symbol_commutative_is_structural() {
        let p = pool();
        let xc = p.symbol_commutative("x", Domain::Real, true);
        let xnc = p.symbol_commutative("x", Domain::Real, false);
        assert_ne!(xc, xnc);
    }

    #[test]
    fn symbol_interning() {
        let p = pool();
        let x1 = p.symbol("x", Domain::Real);
        let x2 = p.symbol("x", Domain::Real);
        assert_eq!(x1, x2, "same symbol must return same ExprId");
    }

    #[test]
    fn domain_is_structural() {
        let p = pool();
        let xr = p.symbol("x", Domain::Real);
        let xc = p.symbol("x", Domain::Complex);
        assert_ne!(xr, xc, "same name but different domain must be distinct");
    }

    #[test]
    fn integer_interning() {
        let p = pool();
        let a = p.integer(42_i32);
        let b = p.integer(42_i32);
        let c = p.integer(99_i32);
        assert_eq!(a, b);
        assert_ne!(a, c);
    }

    #[test]
    fn rational_canonical() {
        let p = pool();
        let r1 = p.rational(2_i32, 4_i32);
        let r2 = p.rational(1_i32, 2_i32);
        assert_eq!(r1, r2, "rationals must be reduced to canonical form");
    }

    #[test]
    fn float_precision_is_structural() {
        let p = pool();
        let f53 = p.float(1.0, 53);
        let f64_ = p.float(1.0, 64);
        assert_ne!(
            f53, f64_,
            "same value but different precision is a different expr"
        );
    }

    #[test]
    fn subexpression_sharing() {
        let p = pool();
        let x = p.symbol("x", Domain::Real);
        let two = p.integer(2_i32);

        let xsq1 = p.pow(x, two);
        let xsq2 = p.pow(x, two);
        assert_eq!(xsq1, xsq2);

        assert_eq!(p.len(), 3);
    }

    #[test]
    fn add_interning() {
        let p = pool();
        let x = p.symbol("x", Domain::Real);
        let y = p.symbol("y", Domain::Real);
        let s1 = p.add(vec![x, y]);
        let s2 = p.add(vec![x, y]);
        assert_eq!(s1, s2);
    }

    #[test]
    fn arg_order_is_canonical() {
        let p = pool();
        let x = p.symbol("x", Domain::Real);
        let y = p.symbol("y", Domain::Real);
        let s1 = p.add(vec![x, y]);
        let s2 = p.add(vec![y, x]);
        assert_eq!(s1, s2, "a+b and b+a must be the same expression after PA-3");
        let m1 = p.mul(vec![x, y]);
        let m2 = p.mul(vec![y, x]);
        assert_eq!(m1, m2, "a*b and b*a must be the same expression after PA-3");
    }

    #[test]
    fn func_interning() {
        let p = pool();
        let x = p.symbol("x", Domain::Real);
        let s1 = p.func("sin", vec![x]);
        let s2 = p.func("sin", vec![x]);
        let c1 = p.func("cos", vec![x]);
        assert_eq!(s1, s2);
        assert_ne!(s1, c1);
    }

    #[test]
    fn pos_infinity_is_interned_symbol() {
        let p = pool();
        let inf = p.pos_infinity();
        assert_eq!(inf, p.pos_infinity());
        assert_eq!(p.display(inf).to_string(), POS_INFINITY_SYMBOL);
    }

    #[test]
    fn display_symbol() {
        let p = pool();
        let x = p.symbol("x", Domain::Real);
        assert_eq!(p.display(x).to_string(), "x");
    }

    #[test]
    fn display_integer() {
        let p = pool();
        let n = p.integer(42_i32);
        assert_eq!(p.display(n).to_string(), "42");
    }

    #[test]
    fn display_pow() {
        let p = pool();
        let x = p.symbol("x", Domain::Real);
        let two = p.integer(2_i32);
        let xsq = p.pow(x, two);
        assert_eq!(p.display(xsq).to_string(), "x^2");
    }

    #[test]
    fn display_add() {
        let p = pool();
        let x = p.symbol("x", Domain::Real);
        let y = p.symbol("y", Domain::Real);
        let s = p.add(vec![x, y]);
        assert_eq!(p.display(s).to_string(), "(x + y)");
    }

    #[test]
    fn display_noncommutative_mul_keeps_order() {
        let p = pool();
        let a = p.symbol_commutative("A", Domain::Real, false);
        let b = p.symbol_commutative("B", Domain::Real, false);
        assert_eq!(p.display(p.mul(vec![b, a])).to_string(), "(B * A)");
    }

    #[test]
    fn display_func() {
        let p = pool();
        let x = p.symbol("x", Domain::Real);
        let s = p.func("sin", vec![x]);
        assert_eq!(p.display(s).to_string(), "sin(x)");
    }

    #[test]
    fn display_nested() {
        let p = pool();
        let x = p.symbol("x", Domain::Real);
        let two = p.integer(2_i32);
        let xsq = p.pow(x, two);
        let one = p.integer(1_i32);
        let expr = p.add(vec![xsq, one]);
        assert_eq!(p.display(expr).to_string(), "(x^2 + 1)");
    }

    #[test]
    fn display_logic_and_quantifiers() {
        let p = pool();
        let x = p.symbol("x", Domain::Real);
        let t = p.pred_true();
        assert_eq!(p.display(t).to_string(), "True");
        assert_eq!(p.display(p.pred_false()).to_string(), "False");
        assert_eq!(p.display(p.pred_not(t)).to_string(), "¬(True)");
        assert_eq!(p.display(p.forall(x, t)).to_string(), "∀ x . True");
        assert_eq!(p.display(p.exists(x, t)).to_string(), "∃ x . True");
    }

    #[test]
    fn display_piecewise_and_big_o() {
        let p = pool();
        let x = p.symbol("x", Domain::Real);
        let zero = p.integer(0_i32);
        let one = p.integer(1_i32);
        let pw = p.piecewise(vec![(p.pred_true(), one)], zero);
        assert_eq!(
            p.display(pw).to_string(),
            "Piecewise((True, 1); default=0)"
        );
        assert_eq!(p.display(p.big_o(x)).to_string(), "O(x)");
    }

    #[test]
    fn debug_matches_display() {
        let p = pool();
        let x = p.symbol("x", Domain::Real);
        let s = p.func("sin", vec![x]);
        assert_eq!(format!("{:?}", p.display(s)), "sin(x)");
    }

    fn assert_send_sync<T: Send + Sync>() {}

    #[test]
    fn pool_is_send_sync() {
        assert_send_sync::<ExprPool>();
    }
}