use std::ops::{BitOr, Index, IndexMut};
use std::{cmp::Ordering, convert::TryFrom};
use std::{
convert::Infallible,
fmt::{self, Debug, Display},
};
use std::{hash::Hash, str::FromStr};
use crate::*;
use fmt::Formatter;
use symbolic_expressions::{Sexp, SexpError};
use thiserror::Error;
/// An e-node language: the node type stored in a [`RecExpr`], whose children
/// are referenced by [`Id`].
///
/// Implementors supply `matches`, `children`, and `children_mut`; every other
/// method has a default implementation built on top of those three.
// `len` counts children and `is_leaf` plays the role an `is_empty` method
// would, hence the clippy allowance.
#[allow(clippy::len_without_is_empty)]
pub trait Language: Debug + Clone + Eq + Ord + Hash {
/// Returns `true` if `self` and `other` match as e-nodes. What "match" means
/// is up to the implementation (e.g. `SymbolLang` compares the operator and
/// the number of children, ignoring the child ids themselves).
fn matches(&self, other: &Self) -> bool;
/// Returns this node's child ids as a slice.
fn children(&self) -> &[Id];
/// Returns this node's child ids as a mutable slice.
fn children_mut(&mut self) -> &mut [Id];
/// Calls `f` on each child id, in order.
fn for_each<F: FnMut(Id)>(&self, f: F) {
self.children().iter().copied().for_each(f)
}
/// Calls `f` on a mutable reference to each child id, in order.
fn for_each_mut<F: FnMut(&mut Id)>(&mut self, f: F) {
self.children_mut().iter_mut().for_each(f)
}
/// Calls `f` on each child id in order, stopping at the first `Err`.
///
/// Returns the first error produced by `f`, or `Ok(())`. Because this is built
/// on `fold`, the remaining children are still visited after an error, but
/// `and_then` ensures `f` is not called on them.
fn try_for_each<E, F>(&self, mut f: F) -> Result<(), E>
where
F: FnMut(Id) -> Result<(), E>,
E: Clone,
{
self.fold(Ok(()), |res, id| res.and_then(|_| f(id)))
}
/// Returns the number of children (counted with `fold`).
fn len(&self) -> usize {
self.fold(0, |len, _| len + 1)
}
/// Returns `true` if this node has no children: `all` with an always-false
/// predicate is `true` only for an empty child list.
fn is_leaf(&self) -> bool {
self.all(|_| false)
}
/// Replaces each child id `id` with `f(id)`, in place.
fn update_children<F: FnMut(Id) -> Id>(&mut self, mut f: F) {
self.for_each_mut(|id| *id = f(*id))
}
/// Like `update_children`, but consumes the node and returns the result.
fn map_children<F: FnMut(Id) -> Id>(mut self, f: F) -> Self {
self.update_children(f);
self
}
/// Left fold over the child ids, in order, starting from `init`.
fn fold<F, T>(&self, init: T, mut f: F) -> T
where
F: FnMut(T, Id) -> T,
T: Clone,
{
let mut acc = init;
// `acc` is captured by an `FnMut` closure and cannot be moved out of it,
// so it is cloned on each step (hence the `T: Clone` bound).
self.for_each(|id| acc = f(acc.clone(), id));
acc
}
/// Returns `true` if `f` holds for every child (vacuously `true` for a leaf).
/// Once the accumulator is `false`, `&&` skips further calls to `f`.
fn all<F: FnMut(Id) -> bool>(&self, mut f: F) -> bool {
self.fold(true, |acc, id| acc && f(id))
}
/// Returns `true` if `f` holds for some child (`false` for a leaf).
/// Once the accumulator is `true`, `||` skips further calls to `f`.
fn any<F: FnMut(Id) -> bool>(&self, mut f: F) -> bool {
self.fold(false, |acc, id| acc || f(id))
}
/// Builds a [`RecExpr`] whose root is a copy of `self`, with each child `id`
/// replaced by the expression returned by `child_recexpr(id)`.
///
/// Each child expression is read with its root as the last node and with
/// child ids referring to earlier positions. Shared sub-terms are copied once
/// per use; no deduplication is done.
///
/// # Panics
///
/// Panics if any expression returned by `child_recexpr` is empty.
fn join_recexprs<F, Expr>(&self, mut child_recexpr: F) -> RecExpr<Self>
where
F: FnMut(Id) -> Expr,
Expr: AsRef<[Self]>,
{
// Copies the last node of `from` (and, recursively, its descendants)
// into `to`. A child `id` points at `from[id]`, so its sub-expression is
// the prefix `from[0..=id]`.
fn build<L: Language>(to: &mut RecExpr<L>, from: &[L]) -> Id {
let last = from.last().unwrap().clone();
let new_node = last.map_children(|id| {
let i = usize::from(id) + 1;
build(to, &from[0..i])
});
to.add(new_node)
}
let mut expr = RecExpr::default();
let node = self
.clone()
.map_children(|id| build(&mut expr, child_recexpr(id).as_ref()));
expr.add(node);
expr
}
/// Builds a [`RecExpr`] rooted at `self`, fetching the node for each child id
/// with `get_node`. Infallible wrapper around `try_build_recexpr`.
fn build_recexpr<F>(&self, mut get_node: F) -> RecExpr<Self>
where
F: FnMut(Id) -> Self,
{
self.try_build_recexpr::<_, std::convert::Infallible>(|id| Ok(get_node(id)))
.unwrap()
}
/// Builds a [`RecExpr`] rooted at `self`, fetching the node for each
/// reachable child id with `get_node` and returning the first error it yields.
///
/// Nodes are visited with an explicit stack rather than recursion.
/// Structurally identical nodes are stored only once, via the `IndexSet`.
/// `get_node` may be called more than once for the same id.
///
/// The graph described by `get_node` must be acyclic. On a cycle, no node on
/// it ever has all of its children resolved, so the `todo` stack grows
/// without bound.
fn try_build_recexpr<F, Err>(&self, mut get_node: F) -> Result<RecExpr<Self>, Err>
where
F: FnMut(Id) -> Result<Self, Err>,
{
// Deduplicated nodes of the new expression, in insertion order.
let mut set = IndexSet::<Self>::default();
// Maps an id handed to `get_node` to its index in `set`.
let mut ids = HashMap::<Id, Id>::default();
let mut todo = self.children().to_vec();
while let Some(id) = todo.last().copied() {
// Already translated (e.g. a shared sub-term): nothing left to do.
if ids.contains_key(&id) {
todo.pop();
continue;
}
let node = get_node(id)?;
// Queue any children not yet translated; this node stays on the stack
// and is revisited after them.
let mut ids_has_all_children = true;
for child in node.children() {
if !ids.contains_key(child) {
ids_has_all_children = false;
todo.push(*child)
}
}
// Every child is translated, so the node can be remapped and interned.
if ids_has_all_children {
let node = node.map_children(|id| ids[&id]);
let new_id = set.insert_full(node).0;
ids.insert(id, Id::from(new_id));
todo.pop();
}
}
// The root goes last, after all of its descendants.
let mut nodes: Vec<Self> = set.into_iter().collect();
nodes.push(self.clone().map_children(|id| ids[&id]));
Ok(RecExpr::from(nodes))
}
}
/// A [`Language`] whose nodes can be built from an operator string and a list
/// of child ids, as done when parsing a [`RecExpr`] from text.
pub trait FromOp: Language + Sized {
/// Error returned when `op` and `children` do not form a valid node.
type Error: Debug;
/// Builds a node with operator `op` and children `children`.
fn from_op(op: &str, children: Vec<Id>) -> Result<Self, Self::Error>;
}
/// A general-purpose error for [`FromOp`] implementations. It records the
/// operator and the children that could not be turned into an e-node.
#[derive(Debug, Error)]
#[error("could not parse an e-node with operator {op:?} and children {children:?}")]
pub struct FromOpError {
    op: String,
    children: Vec<Id>,
}

impl FromOpError {
    /// Captures `op` as an owned string, together with the rejected `children`.
    pub fn new(op: &str, children: Vec<Id>) -> Self {
        let op = String::from(op);
        FromOpError { op, children }
    }
}
/// Storage for the child [`Id`]s of an e-node.
///
/// Implemented below for fixed-size arrays `[Id; N]`, `Box<[Id]>`, `Vec<Id>`,
/// and a single `Id`.
pub trait LanguageChildren {
    /// Number of children currently stored.
    fn len(&self) -> usize;

    /// `true` when no children are stored; derived from [`len`](Self::len).
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Whether this container type is able to hold exactly `n` children.
    fn can_be_length(n: usize) -> bool;

    /// Builds the container from a vector of children.
    fn from_vec(v: Vec<Id>) -> Self;

    /// Borrows the children as a slice.
    fn as_slice(&self) -> &[Id];

    /// Mutably borrows the children as a slice.
    fn as_mut_slice(&mut self) -> &mut [Id];
}
/// Fixed-arity children: exactly `N` ids.
impl<const N: usize> LanguageChildren for [Id; N] {
    fn len(&self) -> usize {
        N
    }

    /// Only a length of exactly `N` is representable.
    fn can_be_length(n: usize) -> bool {
        n == N
    }

    /// Converts `v` into an array of `N` ids.
    ///
    /// # Panics
    ///
    /// Panics if `v.len() != N`. Callers should check
    /// [`LanguageChildren::can_be_length`] first.
    fn from_vec(v: Vec<Id>) -> Self {
        // Convert the owned `Vec` directly (no slice round-trip). On failure,
        // name both lengths instead of unwrapping an opaque `TryFromSliceError`.
        Self::try_from(v).unwrap_or_else(|v: Vec<Id>| {
            panic!("expected exactly {} children, but got {}", N, v.len())
        })
    }

    fn as_slice(&self) -> &[Id] {
        self
    }

    fn as_mut_slice(&mut self) -> &mut [Id] {
        self
    }
}
/// Variable-arity children in a boxed slice (no spare capacity).
#[rustfmt::skip]
impl LanguageChildren for Box<[Id]> {
    fn len(&self) -> usize { self[..].len() }
    fn can_be_length(_n: usize) -> bool { true }
    fn from_vec(v: Vec<Id>) -> Self { v.into_boxed_slice() }
    fn as_slice(&self) -> &[Id] { &self[..] }
    fn as_mut_slice(&mut self) -> &mut [Id] { &mut self[..] }
}
/// Variable-arity children in a growable vector; `from_vec` is the identity.
#[rustfmt::skip]
impl LanguageChildren for Vec<Id> {
    fn len(&self) -> usize { self[..].len() }
    fn can_be_length(_n: usize) -> bool { true }
    fn from_vec(v: Vec<Id>) -> Self { v }
    fn as_slice(&self) -> &[Id] { &self[..] }
    fn as_mut_slice(&mut self) -> &mut [Id] { &mut self[..] }
}
/// A single child stored inline as one `Id`.
#[rustfmt::skip]
impl LanguageChildren for Id {
    fn len(&self) -> usize { 1 }
    fn can_be_length(n: usize) -> bool { n == 1 }
    /// Extracts the single child from `v`.
    ///
    /// # Panics
    ///
    /// Panics unless `v` holds exactly one id. This matches the array impl,
    /// which also rejects a wrong length instead of silently dropping the
    /// extra children or failing with an opaque index panic.
    fn from_vec(v: Vec<Id>) -> Self {
        match &v[..] {
            [id] => *id,
            _ => panic!("expected exactly 1 child, but got {}", v.len()),
        }
    }
    fn as_slice(&self) -> &[Id] { std::slice::from_ref(self) }
    fn as_mut_slice(&mut self) -> &mut [Id] { std::slice::from_mut(self) }
}
/// A recursive expression stored as a flat vector of nodes.
///
/// Children are referenced by [`Id`], which indexes into `nodes`. Nodes
/// appended with [`RecExpr::add`] may only refer to earlier nodes (checked in
/// debug builds), and the root is taken to be the last node.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct RecExpr<L> {
// Nodes in insertion order; a child `Id` is an index into this vector.
nodes: Vec<L>,
}
/// Serializes a [`RecExpr`] as its s-expression string (its `Display` output).
#[cfg(feature = "serde-1")]
impl<L: Language + Display> serde::Serialize for RecExpr<L> {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        // Going through `Display` reuses its special case for an empty
        // expression, so an empty `RecExpr` serializes as `()` instead of
        // panicking. Non-empty output is the same s-expression string.
        serializer.collect_str(self)
    }
}
impl<L> Default for RecExpr<L> {
fn default() -> Self {
Self::from(vec![])
}
}
impl<L> AsRef<[L]> for RecExpr<L> {
    /// Views the nodes, in storage order, as a slice.
    fn as_ref(&self) -> &[L] {
        self.nodes.as_slice()
    }
}
impl<L> From<Vec<L>> for RecExpr<L> {
fn from(nodes: Vec<L>) -> Self {
Self { nodes }
}
}
impl<L: Language> RecExpr<L> {
    /// Appends `node` and returns the `Id` of its position.
    ///
    /// In debug builds, asserts that every child of `node` already refers to a
    /// node of this expression.
    pub fn add(&mut self, node: L) -> Id {
        debug_assert!(
            node.all(|id| usize::from(id) < self.nodes.len()),
            "node {:?} has children not in this expr: {:?}",
            node,
            self
        );
        let id = Id::from(self.nodes.len());
        self.nodes.push(node);
        id
    }

    /// Removes structurally duplicate nodes, remapping children to the kept
    /// copies. Relative order of first occurrences is preserved.
    pub(crate) fn compact(mut self) -> Self {
        let mut remap = HashMap::<Id, Id>::default();
        let mut unique = IndexSet::default();
        for (old_index, node) in self.nodes.drain(..).enumerate() {
            let renamed = node.map_children(|child| remap[&child]);
            let (new_index, _) = unique.insert_full(renamed);
            remap.insert(Id::from(old_index), Id::from(new_index));
        }
        self.nodes.extend(unique);
        self
    }

    /// Returns a new expression holding only the sub-expression rooted at
    /// `new_root`.
    pub(crate) fn extract(&self, new_root: Id) -> Self {
        let root = &self[new_root];
        root.build_recexpr(|id| self[id].clone())
    }

    /// Returns `true` if every node's children point strictly to earlier nodes.
    pub fn is_dag(&self) -> bool {
        self.nodes.iter().enumerate().all(|(index, node)| {
            node.children()
                .iter()
                .all(|&child| usize::from(child) < index)
        })
    }
}
impl<L: Language> Index<Id> for RecExpr<L> {
    type Output = L;

    /// Returns the node at position `id`; panics if `id` is out of range.
    fn index(&self, id: Id) -> &L {
        let position: usize = id.into();
        &self.nodes[position]
    }
}
impl<L: Language> IndexMut<Id> for RecExpr<L> {
    /// Mutable access to the node at position `id`; panics if out of range.
    fn index_mut(&mut self, id: Id) -> &mut L {
        let position: usize = id.into();
        &mut self.nodes[position]
    }
}
impl<L: Language + Display> Display for RecExpr<L> {
    /// Prints the expression as an s-expression; an empty expression prints `()`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if self.nodes.is_empty() {
            return Display::fmt("()", f);
        }
        let rendered = self.to_sexp().to_string();
        Display::fmt(&rendered, f)
    }
}
impl<L: Language + Display> RecExpr<L> {
    /// Converts the expression into an s-expression rooted at its last node.
    ///
    /// An empty expression becomes the empty list `()`, matching the `Display`
    /// impl. Previously `len() - 1` underflowed and panicked here. A non-DAG
    /// expression is logged as a warning and its forward references are
    /// rendered as cycle markers.
    pub(crate) fn to_sexp(&self) -> Sexp {
        let last = match self.nodes.len().checked_sub(1) {
            Some(last) => last,
            None => return Sexp::List(vec![]),
        };
        if !self.is_dag() {
            log::warn!("Tried to print a non-dag: {:?}", self.nodes);
        }
        self.to_sexp_rec(last, &mut |_| None)
    }

    /// Recursively converts node `i` into an s-expression.
    ///
    /// `f` may supply a replacement string for a child index. When it returns
    /// `Some(s)`, that child is rendered as the atom `s` and the rest of the
    /// list is still built. Previously an early `return` discarded the operator
    /// and every sibling. A child that does not point to an earlier node is
    /// rendered as a cycle marker instead of being followed.
    fn to_sexp_rec(&self, i: usize, f: &mut impl FnMut(usize) -> Option<String>) -> Sexp {
        let node = &self.nodes[i];
        let op = Sexp::String(node.to_string());
        if node.is_leaf() {
            op
        } else {
            let mut vec = vec![op];
            for child in node.children().iter().map(|i| usize::from(*i)) {
                vec.push(if let Some(s) = f(child) {
                    Sexp::String(s)
                } else if child < i {
                    self.to_sexp_rec(child, f)
                } else {
                    Sexp::String(format!("<<<< CYCLE to {} = {:?} >>>>", i, node))
                })
            }
            Sexp::List(vec)
        }
    }

    /// Renders the expression with `pretty_print` (defined elsewhere in the
    /// crate) using the given `width`. NOTE(review): confirm the exact
    /// line-breaking rule in `pretty_print`.
    pub fn pretty(&self, width: usize) -> String {
        let sexp = self.to_sexp();
        let mut buf = String::new();
        pretty_print(&mut buf, &sexp, width, 1).unwrap();
        buf
    }
}
/// Errors that can occur while parsing a [`RecExpr`] from a string.
#[derive(Debug, Error)]
pub enum RecExprParseError<E> {
/// An empty s-expression was found where a term was expected.
#[error("found empty s-expression")]
EmptySexp,
/// A list appeared where an operator was expected.
#[error("found a list in the head position: {0}")]
HeadList(Sexp),
/// [`FromOp::from_op`] rejected an operator/children combination.
#[error(transparent)]
BadOp(E),
/// The input was not a well-formed s-expression.
#[error(transparent)]
BadSexp(SexpError),
}
impl<L: FromOp> FromStr for RecExpr<L> {
    type Err = RecExprParseError<L::Error>;

    /// Parses `s` (with surrounding whitespace trimmed) as an s-expression.
    ///
    /// Nodes are appended bottom-up, children before parents, so the root ends
    /// up last. Atoms become leaves through `L::from_op(atom, vec![])`. Repeated
    /// sub-terms are added once per occurrence.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        use RecExprParseError::*;

        fn parse_sexp_into<L: FromOp>(
            sexp: &Sexp,
            expr: &mut RecExpr<L>,
        ) -> Result<Id, RecExprParseError<L::Error>> {
            match sexp {
                Sexp::Empty => Err(EmptySexp),
                Sexp::String(s) => {
                    let node = L::from_op(s, vec![]).map_err(BadOp)?;
                    Ok(expr.add(node))
                }
                Sexp::List(list) if list.is_empty() => Err(EmptySexp),
                Sexp::List(list) => match &list[0] {
                    // Malformed user input must produce an error, not a panic.
                    Sexp::Empty => Err(EmptySexp),
                    list @ Sexp::List(..) => Err(HeadList(list.to_owned())),
                    Sexp::String(op) => {
                        let arg_ids: Vec<Id> = list[1..]
                            .iter()
                            .map(|s| parse_sexp_into(s, expr))
                            .collect::<Result<_, _>>()?;
                        let node = L::from_op(op, arg_ids).map_err(BadOp)?;
                        Ok(expr.add(node))
                    }
                },
            }
        }

        let mut expr = RecExpr::default();
        let sexp = symbolic_expressions::parser::parse_str(s.trim()).map_err(BadSexp)?;
        parse_sexp_into(&sexp, &mut expr)?;
        Ok(expr)
    }
}
/// Outcome of merging an incoming value into an existing one.
///
/// Following the convention of the `merge_*` helpers in this file:
/// - the first field is `true` when the existing (destination) value changed;
/// - the second is `true` when the merged result differs from the incoming
///   value.
///
/// Results combine field-wise with `|`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DidMerge(pub bool, pub bool);

impl BitOr for DidMerge {
    type Output = DidMerge;

    /// Field-wise logical OR of two merge outcomes.
    fn bitor(mut self, rhs: Self) -> Self::Output {
        self.0 |= rhs.0;
        self.1 |= rhs.1;
        self
    }
}
/// Analysis data computed for the e-nodes of an [`EGraph`] and combined with
/// `merge`.
pub trait Analysis<L: Language>: Sized {
/// The per-node analysis value.
type Data: Debug;
/// Computes the analysis data for `enode`.
fn make(egraph: &EGraph<L, Self>, enode: &L) -> Self::Data;
/// Hook that receives `id1`, `id2` and an optional justification. The
/// default does nothing. NOTE(review): the call site is in `EGraph`, not
/// shown here; presumably it runs just before the two ids are unioned.
#[allow(unused_variables)]
fn pre_union(
egraph: &EGraph<L, Self>,
id1: Id,
id2: Id,
justification: &Option<Justification>,
) {
}
/// Merges `b` into `a` and reports the outcome as a [`DidMerge`]: whether
/// `a` changed, and whether the result differs from `b`.
fn merge(&mut self, a: &mut Self::Data, b: Self::Data) -> DidMerge;
/// Hook with mutable access to the e-graph for `id`. The default does nothing.
#[allow(unused_variables)]
fn modify(egraph: &mut EGraph<L, Self>, id: Id) {}
}
/// The trivial analysis: it carries no data and never changes on merge.
impl<L: Language> Analysis<L> for () {
    type Data = ();

    /// There is nothing to compute.
    fn make(_egraph: &EGraph<L, Self>, _enode: &L) -> Self::Data {}

    /// Unit data cannot change, so neither side is reported as changed.
    fn merge(&mut self, _a: &mut Self::Data, _b: Self::Data) -> DidMerge {
        DidMerge(false, false)
    }
}
/// Merges `from` into `to` by keeping the larger of the two.
///
/// Returns `DidMerge(true, false)` when `from` was larger and replaced `to`,
/// `DidMerge(false, true)` when `to` was strictly larger, and
/// `DidMerge(false, false)` when they were equal.
pub fn merge_max<T: Ord>(to: &mut T, from: T) -> DidMerge {
    let order = T::cmp(to, &from);
    if order == Ordering::Less {
        // The incoming value is larger, so it replaces the current one.
        *to = from;
        DidMerge(true, false)
    } else {
        // `to` is kept; it differs from `from` only when strictly greater.
        DidMerge(false, order == Ordering::Greater)
    }
}
/// Merges `from` into `to` by keeping the smaller of the two.
///
/// Returns `DidMerge(true, false)` when `from` was smaller and replaced `to`,
/// `DidMerge(false, true)` when `to` was strictly smaller, and
/// `DidMerge(false, false)` when they were equal.
pub fn merge_min<T: Ord>(to: &mut T, from: T) -> DidMerge {
    let order = T::cmp(to, &from);
    if order == Ordering::Greater {
        // The incoming value is smaller, so it replaces the current one.
        *to = from;
        DidMerge(true, false)
    } else {
        // `to` is kept; it differs from `from` only when strictly less.
        DidMerge(false, order == Ordering::Less)
    }
}
/// Merges an optional incoming value into an optional destination.
///
/// - If nothing is incoming, `to` is left alone. The result differs from
///   `from` exactly when `to` holds a value.
/// - If `to` is empty, it takes the incoming value.
/// - If both hold values, `merge_fn` decides.
pub fn merge_option<T>(
    to: &mut Option<T>,
    from: Option<T>,
    merge_fn: impl FnOnce(&mut T, T) -> DidMerge,
) -> DidMerge {
    let incoming = match from {
        Some(value) => value,
        None => return DidMerge(false, to.is_some()),
    };
    match to.as_mut() {
        Some(existing) => merge_fn(existing, incoming),
        None => {
            *to = Some(incoming);
            DidMerge(true, false)
        }
    }
}
/// A simple [`Language`] whose nodes are an operator [`Symbol`] plus any
/// number of children.
#[derive(Debug, Hash, PartialEq, Eq, Clone, PartialOrd, Ord)]
#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))]
pub struct SymbolLang {
/// The operator of this node.
pub op: Symbol,
/// The child ids of this node, in order.
pub children: Vec<Id>,
}
impl SymbolLang {
    /// Creates a node with operator `op` and the given children.
    pub fn new(op: impl Into<Symbol>, children: Vec<Id>) -> Self {
        Self {
            op: op.into(),
            children,
        }
    }

    /// Creates a childless node (a leaf) with operator `op`.
    pub fn leaf(op: impl Into<Symbol>) -> Self {
        Self::new(op, Vec::new())
    }
}
impl Language for SymbolLang {
    /// Two nodes match when they share an operator and have the same number of
    /// children; the child ids themselves are ignored.
    fn matches(&self, other: &Self) -> bool {
        self.op == other.op && self.children().len() == other.children().len()
    }

    fn children(&self) -> &[Id] {
        &self.children[..]
    }

    fn children_mut(&mut self) -> &mut [Id] {
        &mut self.children[..]
    }
}
impl Display for SymbolLang {
    /// Renders only the operator, forwarding the formatter's flags; children
    /// are not printed.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let op = &self.op;
        Display::fmt(op, f)
    }
}
impl FromOp for SymbolLang {
type Error = Infallible;
fn from_op(op: &str, children: Vec<Id>) -> Result<Self, Self::Error> {
Ok(Self {
op: op.into(),
children,
})
}
}