#![allow(deprecated)]
use std::cmp::Ordering;
use std::collections::HashMap;
use std::collections::hash_map::Entry;
use std::convert::AsRef;
use std::fmt::Debug;
use std::fmt::Write;
use std::hash::Hash;
use std::hash::Hasher;
use std::sync::Arc;
use std::sync::LazyLock;
use std::sync::atomic::AtomicU64;
use std::sync::atomic::Ordering as AtomicOrdering;
use dashmap::DashMap;
use num_bigint::BigInt;
use num_rational::BigRational;
use ordered_float::OrderedFloat;
use super::expr::Expr;
use super::expr::PathType;
use super::expr::SparsePolynomial;
/// Process-wide node interner, created lazily on first access.
///
/// `DagNode` deserialization also draws node ids from this manager's counter.
pub static DAG_MANAGER: LazyLock<DagManager> = LazyLock::new(DagManager::default);
/// A node in the hash-consed expression DAG.
///
/// Nodes are normally obtained through [`DagManager::get_or_create_normalized`] (usually via
/// [`DAG_MANAGER`]), which returns an already-stored node when a structurally equal one exists.
#[derive(Debug, Clone, serde::Serialize)]
pub struct DagNode {
    /// The operation (or leaf value) held by this node.
    pub op: DagOp,
    /// Operand nodes, shared through `Arc`.
    pub children: Vec<Arc<Self>>,
    /// Hash of `op` combined with each child's `hash`, computed when the node is built.
    /// Not serialized; recomputed on deserialization.
    #[serde(skip)]
    pub hash: u64,
    /// Id drawn from the manager's counter (which starts at 1). `PartialEq`, `Ord` and
    /// `Hash` use it as a shortcut when it is non-zero; `0` makes them fall back to
    /// structural comparison. Not serialized; reassigned on deserialization.
    #[serde(skip)]
    pub id: u64,
}
impl<'de> serde::Deserialize<'de> for DagNode {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
#[derive(serde::Deserialize)]
struct DagNodeHelper {
op: DagOp,
children: Vec<Arc<DagNode>>,
}
let helper = DagNodeHelper::deserialize(deserializer)?;
let mut hasher = std::collections::hash_map::DefaultHasher::new();
helper.op.hash(&mut hasher);
for child in &helper.children {
child.hash.hash(&mut hasher);
}
let hash = hasher.finish();
let id = DAG_MANAGER.next_id.fetch_add(1, AtomicOrdering::Relaxed);
Ok(Self {
op: helper.op,
children: helper.children,
hash,
id,
})
}
}
/// The operation (or leaf payload) stored in a [`DagNode`].
///
/// Declaration order is load-bearing. The derived `Ord` compares variants by their position
/// here, and `DagNode`'s `Ord` (used to sort `Add`/`Mul` children during interning) starts
/// by comparing ops. Reordering variants therefore changes canonical child order, and with
/// it node hashes. For serde formats that encode enum variants by index, reordering would
/// also break previously serialized data. Append new variants at the end.
#[derive(
    Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, serde::Serialize, serde::Deserialize,
)]
pub enum DagOp {
    Constant(OrderedFloat<f64>),
    BigInt(BigInt),
    Rational(BigRational),
    Boolean(bool),
    Variable(String),
    Pattern(String),
    Domain(String),
    Pi,
    E,
    Infinity,
    NegativeInfinity,
    InfiniteSolutions,
    NoSolution,
    Derivative(String),
    DerivativeN(String),
    Limit(String),
    Solve(String),
    ConvergenceAnalysis(String),
    ForAll(String),
    Exists(String),
    Substitute(String),
    Ode {
        func: String,
        var: String,
    },
    Pde {
        func: String,
        vars: Vec<String>,
    },
    Predicate {
        name: String,
    },
    Path(PathType),
    Interval(bool, bool),
    Add,
    Sub,
    Mul,
    Div,
    Neg,
    Power,
    Sin,
    Cos,
    Tan,
    Exp,
    Log,
    Abs,
    Sqrt,
    Eq,
    Lt,
    Gt,
    Le,
    Ge,
    Matrix {
        rows: usize,
        cols: usize,
    },
    Vector,
    Complex,
    Transpose,
    MatrixMul,
    MatrixVecMul,
    Inverse,
    Integral,
    VolumeIntegral,
    SurfaceIntegral,
    Sum,
    Series(String),
    Summation(String),
    Product(String),
    AsymptoticExpansion(String),
    Sec,
    Csc,
    Cot,
    ArcSin,
    ArcCos,
    ArcTan,
    ArcSec,
    ArcCsc,
    ArcCot,
    Sinh,
    Cosh,
    Tanh,
    Sech,
    Csch,
    Coth,
    ArcSinh,
    ArcCosh,
    ArcTanh,
    ArcSech,
    ArcCsch,
    ArcCoth,
    LogBase,
    Atan2,
    Binomial,
    Factorial,
    Permutation,
    Combination,
    FallingFactorial,
    RisingFactorial,
    Boundary,
    Gamma,
    Beta,
    Erf,
    Erfc,
    Erfi,
    Zeta,
    BesselJ,
    BesselY,
    LegendreP,
    LaguerreL,
    HermiteH,
    Digamma,
    KroneckerDelta,
    And,
    Or,
    Not,
    Xor,
    Implies,
    Equivalent,
    Union,
    Polynomial,
    SparsePolynomial(SparsePolynomial),
    Floor,
    IsPrime,
    Gcd,
    Mod,
    System,
    Solutions,
    ParametricSolution,
    RootOf {
        index: u32,
    },
    GeneralSolution,
    ParticularSolution,
    Fredholm,
    Volterra,
    Apply,
    Tuple,
    Distribution,
    Max,
    Quantity,
    QuantityWithValue(String),
    CustomZero,
    CustomString(String),
    CustomArcOne,
    CustomArcTwo,
    CustomArcThree,
    CustomArcFour,
    CustomArcFive,
    CustomVecOne,
    CustomVecTwo,
    CustomVecThree,
    CustomVecFour,
    CustomVecFive,
    UnaryList(String),
    BinaryList(String),
    NaryList(String),
}
impl PartialEq for DagNode {
    /// Compares two nodes, using interned ids as a shortcut wherever possible.
    ///
    /// Each pair of nodes is resolved as follows:
    /// - If both have a non-zero `id`, the ids alone decide the pair.
    /// - Otherwise the pair must agree on `op` and child count, and its children are queued
    ///   for comparison.
    ///
    /// An explicit work list is used instead of recursion, so very deep DAGs cannot overflow
    /// the call stack.
    fn eq(&self, other: &Self) -> bool {
        let mut work: Vec<(&Self, &Self)> = Vec::with_capacity(16);
        work.push((self, other));
        while let Some((lhs, rhs)) = work.pop() {
            // Identical allocation: trivially equal, no need to descend.
            if std::ptr::eq(lhs, rhs) {
                continue;
            }
            if lhs.id != 0 && rhs.id != 0 {
                if lhs.id == rhs.id {
                    continue;
                }
                return false;
            }
            let shallow_match = lhs.op == rhs.op && lhs.children.len() == rhs.children.len();
            if !shallow_match {
                return false;
            }
            work.extend(
                lhs.children
                    .iter()
                    .zip(&rhs.children)
                    .map(|(l_child, r_child)| (l_child.as_ref(), r_child.as_ref())),
            );
        }
        true
    }
}
// Marker impl. It relies on `DagOp: Eq` and on structurally equal nodes sharing one non-zero
// id, which interning normally guarantees. NOTE(review): a node with `id == 0` can compare
// structurally equal to two nodes with different ids that are unequal to each other. That
// would break transitivity, so avoid hand-building nodes with `id: 0`.
impl Eq for DagNode {}
impl PartialOrd for DagNode {
    /// Always `Some`: `DagNode` has a total order, so this defers to [`Ord::cmp`].
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let total = Ord::cmp(self, other);
        Some(total)
    }
}
impl Ord for DagNode {
    /// Structural total order over nodes.
    ///
    /// Nodes are compared in pre-order:
    /// 1. by `op` (derived `DagOp` order), then
    /// 2. by child count, then
    /// 3. child by child, left to right.
    ///
    /// Pairs that are the same allocation, or that share a non-zero `id`, are treated as
    /// equal without descending into them. An explicit stack keeps deep DAGs from overflowing
    /// the call stack.
    fn cmp(&self, other: &Self) -> Ordering {
        let mut pending: Vec<(&Self, &Self)> = Vec::with_capacity(16);
        pending.push((self, other));
        while let Some((lhs, rhs)) = pending.pop() {
            // A non-zero id on `lhs` equal to `rhs.id` implies `rhs.id` is non-zero too.
            let same_node = std::ptr::eq(lhs, rhs) || (lhs.id != 0 && lhs.id == rhs.id);
            if same_node {
                continue;
            }
            let head = lhs
                .op
                .cmp(&rhs.op)
                .then_with(|| lhs.children.len().cmp(&rhs.children.len()));
            if head != Ordering::Equal {
                return head;
            }
            // Push in reverse so the leftmost child pair is popped (compared) first.
            pending.extend(
                lhs.children
                    .iter()
                    .zip(&rhs.children)
                    .rev()
                    .map(|(l_child, r_child)| (l_child.as_ref(), r_child.as_ref())),
            );
        }
        Ordering::Equal
    }
}
impl Hash for DagNode {
    /// Hashes the interned `id` when it is non-zero.
    ///
    /// When `id` is zero, it hashes `op` followed by the precomputed structural `hash`
    /// instead. This mirrors `PartialEq`, which compares ids when both are non-zero.
    ///
    /// NOTE(review): a node with `id == 0` that is structurally equal to an interned node
    /// will hash differently from it.
    fn hash<H: Hasher>(&self, state: &mut H) {
        match self.id {
            0 => {
                self.op.hash(state);
                state.write_u64(self.hash);
            },
            interned_id => state.write_u64(interned_id),
        }
    }
}
impl From<DagNode> for Expr {
    /// Converts a DAG node back into an `Expr` via `DagNode::to_expr`.
    ///
    /// # Panics
    ///
    /// Panics if `to_expr` returns an error. Call `to_expr` directly to handle the failure.
    fn from(node: DagNode) -> Self {
        let converted = node.to_expr();
        converted.expect("Cannot convert DagNode to Expr.")
    }
}
/// Interning table for [`DagNode`]s.
///
/// Nodes are bucketed by their structural hash, so structurally equal nodes can be shared.
pub struct DagManager {
    /// Next id to hand out; starts at 1 so that `0` never denotes an interned node.
    next_id: AtomicU64,
    /// Structural hash -> nodes with that hash (collisions share a bucket).
    nodes: DashMap<u64, Vec<Arc<DagNode>>>,
}
impl Default for DagManager {
fn default() -> Self {
Self::new()
}
}
impl DagManager {
    /// Creates an empty manager whose first allocated node id is `1`.
    ///
    /// Id `0` is reserved: `DagNode`'s `PartialEq`/`Ord`/`Hash` only use ids when they are
    /// non-zero.
    #[inline]
    #[must_use]
    pub fn new() -> Self {
        Self {
            nodes: DashMap::new(),
            next_id: AtomicU64::new(1),
        }
    }

    /// Returns the interned node for `op` applied to `children`, creating it if needed.
    ///
    /// For `Add` and `Mul`, `children` are sorted first (using `DagNode`'s structural `Ord`),
    /// so operand order does not affect identity.
    ///
    /// If an equal node is already stored, that node is returned. Otherwise a new node with a
    /// fresh id is created and stored. Once a single hash bucket holds `MAX_BUCKET_SIZE`
    /// nodes, further new nodes for that hash are returned without being stored. Such nodes
    /// are not shared, and because `DagNode` equality compares ids, they compare unequal to
    /// later structurally identical nodes.
    ///
    /// # Errors
    ///
    /// Returns `Err` if `children.len()` exceeds `MAX_CHILDREN`.
    #[inline]
    pub fn get_or_create_normalized(
        &self,
        op: DagOp,
        mut children: Vec<Arc<DagNode>>,
    ) -> Result<Arc<DagNode>, String> {
        const MAX_CHILDREN: usize = 10000;
        const MAX_BUCKET_SIZE: usize = 1000;
        if children.len() > MAX_CHILDREN {
            return Err(format!(
                "Too many children in node ({}), exceeds limit of {MAX_CHILDREN}",
                children.len()
            ));
        }
        if matches!(op, DagOp::Add | DagOp::Mul) {
            children.sort();
        }
        let hash = Self::compute_op_children_hash(&op, &children);
        // The shard stays locked while `entry` lives, so lookup and insert are atomic with
        // respect to other callers hashing to the same key.
        let mut entry = self.nodes.entry(hash).or_default();
        let bucket = entry.value_mut();
        // Always look for an existing node first, even when the bucket is full. Otherwise a
        // full bucket would hand out duplicates of nodes that are already interned.
        if let Some(existing) = bucket
            .iter()
            .find(|cand| Self::matches_hashed(cand, hash, &op, &children))
        {
            return Ok(Arc::clone(existing));
        }
        let id = self.next_id.fetch_add(1, AtomicOrdering::Relaxed);
        let node = Arc::new(DagNode {
            op,
            children,
            hash,
            id,
        });
        if bucket.len() < MAX_BUCKET_SIZE {
            bucket.push(Arc::clone(&node));
        }
        Ok(node)
    }

    /// Returns `true` if `cand` has exactly the given `op` and `children`.
    ///
    /// Cheap checks (hash, op, arity, per-child hashes) run before the full structural
    /// comparison of the children.
    pub(crate) fn dag_nodes_structurally_equal(
        cand: &Arc<DagNode>,
        op: &DagOp,
        children: &[Arc<DagNode>],
    ) -> bool {
        Self::matches_hashed(
            cand,
            Self::compute_op_children_hash(op, children),
            op,
            children,
        )
    }

    /// Same as `dag_nodes_structurally_equal`, but with the `(op, children)` hash already
    /// computed, so bucket scans do not rehash once per candidate.
    fn matches_hashed(
        cand: &Arc<DagNode>,
        hash: u64,
        op: &DagOp,
        children: &[Arc<DagNode>],
    ) -> bool {
        cand.hash == hash
            && cand.op == *op
            && cand.children.len() == children.len()
            && cand
                .children
                .iter()
                .zip(children)
                .all(|(a, b)| a.hash == b.hash)
            && cand.children.as_slice() == children
    }

    /// Structural hash of a node: `op`, then each child's precomputed `hash`, fed to `ahash`.
    pub(crate) fn compute_op_children_hash(
        op: &DagOp,
        children: &[Arc<DagNode>],
    ) -> u64 {
        let mut hasher = ahash::AHasher::default();
        op.hash(&mut hasher);
        for c in children {
            Self::c_hash_for_hasher(c, &mut hasher);
        }
        hasher.finish()
    }

    /// Feeds a child's contribution (its precomputed `hash`) into `hasher`.
    pub(crate) fn c_hash_for_hasher(
        c: &Arc<DagNode>,
        hasher: &mut ahash::AHasher,
    ) {
        hasher.write_u64(c.hash);
    }

    /// Converts `expr` into an interned DAG node, interning every subexpression bottom-up.
    ///
    /// `Expr::Dag` subtrees are reused as-is. The traversal is iterative (post-order with an
    /// explicit stack), so deeply nested expressions cannot overflow the call stack.
    ///
    /// # Errors
    ///
    /// Propagates errors from `to_dag_op_internal` and `get_or_create_normalized`. Also
    /// returns `Err` if the internal result stack underflows or ends up empty.
    #[inline]
    pub fn get_or_create(
        &self,
        expr: &Expr,
    ) -> Result<Arc<DagNode>, String> {
        if let Expr::Dag(node) = expr {
            return Ok(Arc::clone(node));
        }
        // `None`: not yet expanded. `Some(n)`: its `n` children have been pushed and, by the
        // time this entry is popped again, sit in order on top of `built`.
        let mut pending: Vec<(Expr, Option<usize>)> = vec![(expr.clone(), None)];
        let mut built: Vec<Arc<DagNode>> = Vec::new();
        while let Some((current, child_count)) = pending.pop() {
            if let Expr::Dag(node) = &current {
                built.push(Arc::clone(node));
                continue;
            }
            if let Some(count) = child_count {
                let op = current.to_dag_op_internal()?;
                let start = built
                    .len()
                    .checked_sub(count)
                    .ok_or("Result stack underflow")?;
                let children = built.split_off(start);
                built.push(self.get_or_create_normalized(op, children)?);
                continue;
            }
            let children = current.get_children_internal();
            if children.is_empty() {
                let op = current.to_dag_op_internal()?;
                built.push(self.get_or_create_normalized(op, Vec::new())?);
            } else {
                let count = children.len();
                pending.push((current, Some(count)));
                // Reverse so the first child is processed (and lands on `built`) first.
                pending.extend(children.into_iter().rev().map(|child| (child, None)));
            }
        }
        built
            .pop()
            .ok_or_else(|| "Root node not found".to_string())
    }
}