use crate::kernel::{
domain::Domain,
expr::{BigFloat, BigInt, BigRat, ExprData, ExprId},
};
use std::fmt;
/// Name of the symbol interned by [`ExprPool::pos_infinity`] (U+221E INFINITY, `∞`).
pub const POS_INFINITY_SYMBOL: &str = "\u{221E}";
#[cfg(feature = "parallel")]
use dashmap::DashMap;
#[cfg(not(feature = "parallel"))]
use std::collections::HashMap;
#[cfg(not(feature = "parallel"))]
use std::sync::Mutex;
/// Lookup from node contents to their interned id — single-threaded backend.
///
/// `ExprPool` wraps this variant in a `Mutex`, so it needs no internal synchronization.
#[cfg(not(feature = "parallel"))]
struct PoolIndex(HashMap<ExprData, ExprId>);
/// Lookup from node contents to their interned id — concurrent `DashMap` backend.
#[cfg(feature = "parallel")]
struct PoolIndex(DashMap<ExprData, ExprId>);
#[cfg(feature = "parallel")]
impl PoolIndex {
    /// Creates an empty index.
    fn new() -> Self {
        Self(DashMap::new())
    }

    /// Returns the id already assigned to `data`, if any.
    fn get(&self, data: &ExprData) -> Option<ExprId> {
        let guard = self.0.get(data)?;
        Some(*guard.value())
    }

    /// Returns the id stored under `key`. `f` is called to produce and store a new id only
    /// when `key` is absent.
    fn or_insert_with(&self, key: ExprData, f: impl FnOnce() -> ExprId) -> ExprId {
        let slot = self.0.entry(key).or_insert_with(f);
        *slot.value()
    }
}
#[cfg(not(feature = "parallel"))]
impl PoolIndex {
fn new() -> Self {
PoolIndex(HashMap::new())
}
fn get(&self, data: &ExprData) -> Option<ExprId> {
self.0.get(data).copied()
}
fn insert(&mut self, data: ExprData, id: ExprId) {
self.0.insert(data, id);
}
}
/// Hash-consing arena for expression nodes.
///
/// Each distinct [`ExprData`] value is pushed into `nodes` once; its position there is
/// its [`ExprId`]. `index` maps node contents back to ids so re-interning is a lookup.
pub struct ExprPool {
    /// Node storage; only ever appended to. An `ExprId` is an index into it.
    nodes: boxcar::Vec<ExprData>,
    /// Contents -> id lookup, guarded by a `Mutex` in the single-threaded build.
    #[cfg(not(feature = "parallel"))]
    index: Mutex<PoolIndex>,
    /// Contents -> id lookup backed by a concurrent map.
    #[cfg(feature = "parallel")]
    index: PoolIndex,
}
// SAFETY: NOTE(review): no justification is recorded for these manual impls. `boxcar::Vec`,
// `DashMap` and `Mutex<HashMap<..>>` are normally auto-`Send`/`Sync` when their element
// types are, so these impls only make a difference if `ExprData`/`ExprId` are *not*
// `Send + Sync` — in which case sharing a pool across threads may be unsound. Confirm, and
// prefer deleting them so the compiler derives (and keeps checking) the auto traits.
unsafe impl Send for ExprPool {}
unsafe impl Sync for ExprPool {}
impl ExprPool {
    /// Creates an empty pool.
    pub fn new() -> Self {
        ExprPool {
            nodes: boxcar::Vec::new(),
            #[cfg(feature = "parallel")]
            index: PoolIndex::new(),
            #[cfg(not(feature = "parallel"))]
            index: Mutex::new(PoolIndex::new()),
        }
    }

    /// Converts a slot index in `nodes` into an [`ExprId`].
    ///
    /// # Panics
    /// Panics if `index` does not fit in a `u32`. A plain `as` cast would silently wrap and
    /// hand out an id that aliases an existing node.
    fn id_for_index(index: usize) -> ExprId {
        let raw: u32 = std::convert::TryFrom::try_from(index)
            .expect("ExprPool: node count exceeds the u32 ExprId range");
        ExprId(raw)
    }

    /// Returns the id of `data`, appending it to the pool first if it is not present.
    ///
    /// Structurally equal values always map to the same `ExprId`.
    ///
    /// # Panics
    /// Panics if the pool outgrows the `u32` id space. Without the `parallel` feature, it also
    /// panics if the index mutex was poisoned by a panic in another thread.
    pub fn intern(&self, data: ExprData) -> ExprId {
        #[cfg(feature = "parallel")]
        {
            // Fast path: read-only lookup for the common "already interned" case.
            if let Some(id) = self.index.get(&data) {
                return id;
            }
            // Another thread may have inserted `data` since the lookup above. The closure
            // (which pushes the node) only runs when the key is still vacant.
            self.index
                .or_insert_with(data.clone(), || Self::id_for_index(self.nodes.push(data)))
        }
        #[cfg(not(feature = "parallel"))]
        {
            // The lock is held across lookup, push and insert so `nodes` and the index
            // cannot diverge.
            let mut idx = self.index.lock().expect("ExprPool index Mutex poisoned");
            if let Some(id) = idx.get(&data) {
                return id;
            }
            let id = Self::id_for_index(self.nodes.push(data.clone()));
            idx.insert(data, id);
            id
        }
    }

    /// Calls `f` with a borrow of the node stored under `id` and returns its result.
    ///
    /// # Panics
    /// Panics if `id` is out of range for this pool.
    pub fn with<R, F: FnOnce(&ExprData) -> R>(&self, id: ExprId, f: F) -> R {
        let data = self
            .nodes
            .get(id.0 as usize)
            .expect("ExprPool: ExprId out of range");
        f(data)
    }

    /// Returns an owned clone of the node stored under `id`.
    ///
    /// Prefer [`ExprPool::with`] when a borrow suffices; cloning copies big numbers and
    /// argument vectors.
    ///
    /// # Panics
    /// Panics if `id` is out of range for this pool.
    pub fn get(&self, id: ExprId) -> ExprData {
        self.with(id, ExprData::clone)
    }

    /// Number of distinct nodes interned so far.
    pub fn len(&self) -> usize {
        self.nodes.count()
    }

    /// `true` if nothing has been interned yet (consistent with [`ExprPool::len`]).
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Interns a commutative symbol named `name` in `domain`.
    pub fn symbol(&self, name: impl Into<String>, domain: Domain) -> ExprId {
        self.symbol_commutative(name, domain, true)
    }

    /// Interns a symbol, recording whether it commutes under multiplication.
    ///
    /// Name, domain and the `commutative` flag are all part of the node's identity.
    pub fn symbol_commutative(
        &self,
        name: impl Into<String>,
        domain: Domain,
        commutative: bool,
    ) -> ExprId {
        self.intern(ExprData::Symbol {
            name: name.into(),
            domain,
            commutative,
        })
    }

    /// Interns an arbitrary-precision integer.
    pub fn integer(&self, n: impl Into<rug::Integer>) -> ExprId {
        self.intern(ExprData::Integer(BigInt(n.into())))
    }

    /// Interns the rational `numer / denom`.
    ///
    /// The value is canonicalized by `rug::Rational`, so e.g. `2/4` and `1/2` intern to
    /// the same node.
    ///
    /// # Panics
    /// Panics if `denom` is zero (rejected by `rug::Rational`).
    pub fn rational(
        &self,
        numer: impl Into<rug::Integer>,
        denom: impl Into<rug::Integer>,
    ) -> ExprId {
        let r = rug::Rational::from((numer.into(), denom.into()));
        self.intern(ExprData::Rational(BigRat(r)))
    }

    /// Interns `value` as a float with `prec` bits of precision.
    ///
    /// The precision is part of the node's identity, so the same value at different
    /// precisions yields distinct nodes.
    ///
    /// # Panics
    /// Panics if `prec` is outside the range accepted by `rug::Float::with_val`.
    pub fn float(&self, value: f64, prec: u32) -> ExprId {
        let f = rug::Float::with_val(prec, value);
        self.intern(ExprData::Float(BigFloat { inner: f, prec }))
    }

    /// Interns the sum of `args`.
    ///
    /// Arguments are sorted by id so `a + b` and `b + a` share one node. No flattening,
    /// simplification or duplicate merging happens here.
    pub fn add(&self, mut args: Vec<ExprId>) -> ExprId {
        args.sort_unstable();
        self.intern(ExprData::Add(args))
    }

    /// Interns the product of `args`.
    ///
    /// Arguments are sorted only when every factor's multiplication tree is commutative.
    /// Otherwise their order is kept, so `A*B` and `B*A` stay distinct for non-commuting
    /// operands.
    pub fn mul(&self, mut args: Vec<ExprId>) -> ExprId {
        let sort_ok = args
            .iter()
            .all(|&a| crate::kernel::expr_props::mult_tree_is_commutative(self, a));
        if sort_ok {
            args.sort_unstable();
        }
        self.intern(ExprData::Mul(args))
    }

    /// Interns `base ^ exp`.
    pub fn pow(&self, base: ExprId, exp: ExprId) -> ExprId {
        self.intern(ExprData::Pow { base, exp })
    }

    /// Interns the application `name(args...)`; argument order is preserved.
    pub fn func(&self, name: impl Into<String>, args: Vec<ExprId>) -> ExprId {
        self.intern(ExprData::Func {
            name: name.into(),
            args,
        })
    }

    /// Interns a piecewise expression from `(condition, value)` branches plus a `default`.
    pub fn piecewise(&self, branches: Vec<(ExprId, ExprId)>, default: ExprId) -> ExprId {
        self.intern(ExprData::Piecewise { branches, default })
    }

    /// Interns a predicate of the given `kind` over `args`.
    ///
    /// The argument count is not validated against `kind`.
    pub fn predicate(&self, kind: crate::kernel::expr::PredicateKind, args: Vec<ExprId>) -> ExprId {
        self.intern(ExprData::Predicate { kind, args })
    }

    /// Interns `a < b`.
    pub fn pred_lt(&self, a: ExprId, b: ExprId) -> ExprId {
        self.predicate(crate::kernel::expr::PredicateKind::Lt, vec![a, b])
    }

    /// Interns `a <= b`.
    pub fn pred_le(&self, a: ExprId, b: ExprId) -> ExprId {
        self.predicate(crate::kernel::expr::PredicateKind::Le, vec![a, b])
    }

    /// Interns `a > b`.
    pub fn pred_gt(&self, a: ExprId, b: ExprId) -> ExprId {
        self.predicate(crate::kernel::expr::PredicateKind::Gt, vec![a, b])
    }

    /// Interns `a >= b`.
    pub fn pred_ge(&self, a: ExprId, b: ExprId) -> ExprId {
        self.predicate(crate::kernel::expr::PredicateKind::Ge, vec![a, b])
    }

    /// Interns `a = b`.
    pub fn pred_eq(&self, a: ExprId, b: ExprId) -> ExprId {
        self.predicate(crate::kernel::expr::PredicateKind::Eq, vec![a, b])
    }

    /// Interns `a != b`.
    pub fn pred_ne(&self, a: ExprId, b: ExprId) -> ExprId {
        self.predicate(crate::kernel::expr::PredicateKind::Ne, vec![a, b])
    }

    /// Interns the conjunction of `args` (order preserved).
    pub fn pred_and(&self, args: Vec<ExprId>) -> ExprId {
        self.predicate(crate::kernel::expr::PredicateKind::And, args)
    }

    /// Interns the disjunction of `args` (order preserved).
    pub fn pred_or(&self, args: Vec<ExprId>) -> ExprId {
        self.predicate(crate::kernel::expr::PredicateKind::Or, args)
    }

    /// Interns the negation of `a`.
    pub fn pred_not(&self, a: ExprId) -> ExprId {
        self.predicate(crate::kernel::expr::PredicateKind::Not, vec![a])
    }

    /// Interns the nullary `True` predicate.
    pub fn pred_true(&self) -> ExprId {
        self.predicate(crate::kernel::expr::PredicateKind::True, vec![])
    }

    /// Interns the nullary `False` predicate.
    pub fn pred_false(&self) -> ExprId {
        self.predicate(crate::kernel::expr::PredicateKind::False, vec![])
    }

    /// Interns the quantified node `∀ var . body`.
    pub fn forall(&self, var: ExprId, body: ExprId) -> ExprId {
        self.intern(ExprData::Forall { var, body })
    }

    /// Interns the quantified node `∃ var . body`.
    pub fn exists(&self, var: ExprId, body: ExprId) -> ExprId {
        self.intern(ExprData::Exists { var, body })
    }

    /// Interns the big-O term `O(arg)`.
    pub fn big_o(&self, arg: ExprId) -> ExprId {
        self.intern(ExprData::BigO(arg))
    }

    /// Interns the `POS_INFINITY_SYMBOL` symbol in the `Positive` domain.
    pub fn pos_infinity(&self) -> ExprId {
        self.symbol(POS_INFINITY_SYMBOL, Domain::Positive)
    }

    /// Returns a lightweight adapter implementing `Display`/`Debug` that renders `id`.
    pub fn display(&self, id: ExprId) -> ExprDisplay<'_> {
        ExprDisplay { id, pool: self }
    }
}
impl Default for ExprPool {
fn default() -> Self {
Self::new()
}
}
/// Borrowed view that renders the expression `id` stored in `pool`.
///
/// Created by `ExprPool::display`. Both fields are `Copy`, so the adapter is too, and a
/// single handle can be formatted repeatedly.
#[derive(Clone, Copy)]
pub struct ExprDisplay<'a> {
    /// Root of the expression to render.
    pub id: ExprId,
    /// Pool that owns `id` and all of its sub-expressions.
    pub pool: &'a ExprPool,
}
impl fmt::Display for ExprDisplay<'_> {
    /// Renders the expression tree rooted at `self.id`.
    ///
    /// The node is borrowed through `ExprPool::with` rather than cloned, because rendering
    /// recurses once per sub-expression. Cloning would copy every node's big numbers and
    /// argument vectors just to print them.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.pool.with(self.id, |data| fmt_data(data, self.pool, f))
    }
}
impl fmt::Debug for ExprDisplay<'_> {
    /// Debug output is exactly the `Display` rendering.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(self, f)
    }
}
fn fmt_data(data: &ExprData, pool: &ExprPool, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match data {
ExprData::Symbol { name, .. } => write!(f, "{}", name),
ExprData::Integer(n) => write!(f, "{}", n),
ExprData::Rational(r) => write!(f, "{}", r),
ExprData::Float(fl) => write!(f, "{}", fl),
ExprData::Add(args) => {
write!(f, "(")?;
for (i, &arg) in args.iter().enumerate() {
if i > 0 {
write!(f, " + ")?;
}
write!(f, "{}", pool.display(arg))?;
}
write!(f, ")")
}
ExprData::Mul(args) => {
write!(f, "(")?;
for (i, &arg) in args.iter().enumerate() {
if i > 0 {
write!(f, " * ")?;
}
write!(f, "{}", pool.display(arg))?;
}
write!(f, ")")
}
ExprData::Pow { base, exp } => {
write!(f, "{}^{}", pool.display(*base), pool.display(*exp))
}
ExprData::Func { name, args } => {
write!(f, "{}(", name)?;
for (i, &arg) in args.iter().enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{}", pool.display(arg))?;
}
write!(f, ")")
}
ExprData::Piecewise { branches, default } => {
write!(f, "Piecewise(")?;
for (i, (cond, val)) in branches.iter().enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "({}, {})", pool.display(*cond), pool.display(*val))?;
}
write!(f, "; default={})", pool.display(*default))
}
ExprData::Predicate { kind, args } => match kind {
crate::kernel::expr::PredicateKind::True => write!(f, "True"),
crate::kernel::expr::PredicateKind::False => write!(f, "False"),
crate::kernel::expr::PredicateKind::Not => {
write!(f, "¬({})", pool.display(args[0]))
}
crate::kernel::expr::PredicateKind::And | crate::kernel::expr::PredicateKind::Or => {
write!(f, "(")?;
for (i, &arg) in args.iter().enumerate() {
if i > 0 {
write!(f, " {} ", kind)?;
}
write!(f, "{}", pool.display(arg))?;
}
write!(f, ")")
}
_ => {
write!(
f,
"({} {} {})",
pool.display(args[0]),
kind,
pool.display(args[1])
)
}
},
ExprData::Forall { var, body } => {
write!(f, "∀ {} . {}", pool.display(*var), pool.display(*body))
}
ExprData::Exists { var, body } => {
write!(f, "∃ {} . {}", pool.display(*var), pool.display(*body))
}
ExprData::BigO(arg) => {
write!(f, "O({})", pool.display(*arg))
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::domain::Domain;

    /// A fresh, empty pool so every test starts numbering ids from scratch.
    fn pool() -> ExprPool {
        ExprPool::new()
    }

    /// Renders `id` through the pool's `Display` adapter.
    fn render(p: &ExprPool, id: ExprId) -> String {
        p.display(id).to_string()
    }

    #[test]
    fn noncommutative_mul_orders_distinct() {
        let p = pool();
        let a = p.symbol_commutative("A", Domain::Real, false);
        let b = p.symbol_commutative("B", Domain::Real, false);
        let ab = p.mul(vec![a, b]);
        let ba = p.mul(vec![b, a]);
        assert_ne!(
            ab, ba,
            "A*B and B*A must not hash-cons together for NC symbols"
        );
    }

    #[test]
    fn symbol_commutative_is_structural() {
        let p = pool();
        let commuting = p.symbol_commutative("x", Domain::Real, true);
        let non_commuting = p.symbol_commutative("x", Domain::Real, false);
        assert_ne!(commuting, non_commuting);
    }

    #[test]
    fn symbol_interning() {
        let p = pool();
        let first = p.symbol("x", Domain::Real);
        let again = p.symbol("x", Domain::Real);
        assert_eq!(first, again, "same symbol must return same ExprId");
    }

    #[test]
    fn domain_is_structural() {
        let p = pool();
        let real = p.symbol("x", Domain::Real);
        let complex = p.symbol("x", Domain::Complex);
        assert_ne!(real, complex, "same name but different domain must be distinct");
    }

    #[test]
    fn integer_interning() {
        let p = pool();
        let [forty_two, forty_two_again, ninety_nine] = [42_i32, 42, 99].map(|n| p.integer(n));
        assert_eq!(forty_two, forty_two_again);
        assert_ne!(forty_two, ninety_nine);
    }

    #[test]
    fn rational_canonical() {
        let p = pool();
        let unreduced = p.rational(2_i32, 4_i32);
        let reduced = p.rational(1_i32, 2_i32);
        assert_eq!(unreduced, reduced, "rationals must be reduced to canonical form");
    }

    #[test]
    fn float_precision_is_structural() {
        let p = pool();
        let (double, extended) = (p.float(1.0, 53), p.float(1.0, 64));
        assert_ne!(
            double, extended,
            "same value but different precision is a different expr"
        );
    }

    #[test]
    fn subexpression_sharing() {
        let p = pool();
        let x = p.symbol("x", Domain::Real);
        let two = p.integer(2_i32);
        let squares = [p.pow(x, two), p.pow(x, two)];
        assert_eq!(squares[0], squares[1]);
        // Only x, 2 and x^2 exist: the second `pow` must reuse the first node.
        assert_eq!(p.len(), 3);
    }

    #[test]
    fn add_interning() {
        let p = pool();
        let x = p.symbol("x", Domain::Real);
        let y = p.symbol("y", Domain::Real);
        let build = || p.add(vec![x, y]);
        assert_eq!(build(), build());
    }

    #[test]
    fn arg_order_is_canonical() {
        let p = pool();
        let x = p.symbol("x", Domain::Real);
        let y = p.symbol("y", Domain::Real);
        let (sum_xy, sum_yx) = (p.add(vec![x, y]), p.add(vec![y, x]));
        assert_eq!(sum_xy, sum_yx, "a+b and b+a must be the same expression after PA-3");
        let (prod_xy, prod_yx) = (p.mul(vec![x, y]), p.mul(vec![y, x]));
        assert_eq!(prod_xy, prod_yx, "a*b and b*a must be the same expression after PA-3");
    }

    #[test]
    fn func_interning() {
        let p = pool();
        let x = p.symbol("x", Domain::Real);
        let sin_a = p.func("sin", vec![x]);
        let sin_b = p.func("sin", vec![x]);
        let cos = p.func("cos", vec![x]);
        assert_eq!(sin_a, sin_b);
        assert_ne!(sin_a, cos);
    }

    #[test]
    fn display_symbol() {
        let p = pool();
        let x = p.symbol("x", Domain::Real);
        assert_eq!(render(&p, x), "x");
    }

    #[test]
    fn display_integer() {
        let p = pool();
        let answer = p.integer(42_i32);
        assert_eq!(render(&p, answer), "42");
    }

    #[test]
    fn display_pow() {
        let p = pool();
        let x = p.symbol("x", Domain::Real);
        let two = p.integer(2_i32);
        let square = p.pow(x, two);
        assert_eq!(render(&p, square), "x^2");
    }

    #[test]
    fn display_add() {
        let p = pool();
        let x = p.symbol("x", Domain::Real);
        let y = p.symbol("y", Domain::Real);
        let sum = p.add(vec![x, y]);
        assert_eq!(render(&p, sum), "(x + y)");
    }

    #[test]
    fn display_func() {
        let p = pool();
        let x = p.symbol("x", Domain::Real);
        let sine = p.func("sin", vec![x]);
        assert_eq!(render(&p, sine), "sin(x)");
    }

    #[test]
    fn display_nested() {
        let p = pool();
        // Creation order fixes the ids, and therefore the sorted order inside `add`.
        let x = p.symbol("x", Domain::Real);
        let two = p.integer(2_i32);
        let square = p.pow(x, two);
        let one = p.integer(1_i32);
        let total = p.add(vec![square, one]);
        assert_eq!(render(&p, total), "(x^2 + 1)");
    }

    /// Compiles only if `T: Send + Sync`.
    fn assert_send_sync<T: Send + Sync>() {}

    #[test]
    fn pool_is_send_sync() {
        assert_send_sync::<ExprPool>();
    }
}