use std::convert::TryFrom;
use std::fmt::{self, Debug, Display};
use std::hash::Hash;
use crate::{EGraph, Id, Symbol};
use symbolic_expressions::Sexp;
/// Trait that defines a language: a node is an operator plus an ordered list of
/// child [`Id`]s, which `children`/`children_mut` expose as a slice.
// `len` counts children and `is_leaf` is the emptiness check, so there is no
// `is_empty`; this silences the clippy lint that would otherwise ask for one.
#[allow(clippy::len_without_is_empty)]
pub trait Language: Debug + Clone + Eq + Ord + Hash {
    /// Returns `true` if `self` and `other` should be treated as the same node when
    /// their children are ignored. (`SymbolLang` implements this as "same `op` and
    /// same number of children".)
    fn matches(&self, other: &Self) -> bool;

    /// The child ids of this node, in order.
    fn children(&self) -> &[Id];

    /// Mutable access to the child ids of this node, in order.
    fn children_mut(&mut self) -> &mut [Id];

    /// Calls `f` on each child id, in order.
    fn for_each<F: FnMut(Id)>(&self, f: F) {
        self.children().iter().copied().for_each(f)
    }

    /// Calls `f` with a mutable reference to each child id, in order.
    fn for_each_mut<F: FnMut(&mut Id)>(&mut self, f: F) {
        self.children_mut().iter_mut().for_each(f)
    }

    /// The operator of this node, used when printing a `RecExpr` as an
    /// s-expression. The default implementation panics via `unimplemented!`.
    fn display_op(&self) -> &dyn Display {
        unimplemented!("display_op not implemented")
    }

    /// Builds a node from an operator string and its already-parsed child ids; this
    /// is what `RecExpr`'s `FromStr` calls. Implementations report an invalid
    /// operator/children combination as `Err(message)`. The default implementation
    /// panics via `unimplemented!`.
    #[allow(unused_variables)]
    fn from_op_str(op_str: &str, children: Vec<Id>) -> Result<Self, String> {
        unimplemented!("from_op_str not implemented")
    }

    /// The number of children.
    fn len(&self) -> usize {
        self.children().len()
    }

    /// `true` if this node has no children.
    fn is_leaf(&self) -> bool {
        self.children().is_empty()
    }

    /// Replaces every child id `id` with `f(id)`, in place and in order.
    fn update_children<F: FnMut(Id) -> Id>(&mut self, mut f: F) {
        self.for_each_mut(|id| *id = f(*id))
    }

    /// Like `update_children`, but consumes `self` and returns the updated node.
    fn map_children<F: FnMut(Id) -> Id>(mut self, f: F) -> Self {
        self.update_children(f);
        self
    }

    /// Folds `f` over the child ids in order, starting from `init`.
    fn fold<F, T>(&self, init: T, mut f: F) -> T
    where
        F: FnMut(T, Id) -> T,
        T: Clone,
    {
        let mut acc = init;
        // The `for_each` closure only borrows `acc` mutably, but `f` takes `T` by
        // value, so the accumulator is cloned on every step (hence `T: Clone`).
        self.for_each(|id| acc = f(acc.clone(), id));
        acc
    }

    /// Builds a standalone `RecExpr` whose root is a copy of `self`.
    ///
    /// `child_recexpr(id)` must return the nodes of the expression for child `id`:
    /// its root is the last element, and every child id inside it indexes into that
    /// same slice. Those nodes are copied into the result (children before parents)
    /// and their ids are rewritten to point into the new expression. A subterm that
    /// is reachable along several paths is added once per path; nothing is
    /// deduplicated.
    ///
    /// # Panics
    /// Panics if `child_recexpr` returns an empty slice, or if a child id inside a
    /// returned slice points outside that slice.
    fn to_recexpr<'a, F>(&self, mut child_recexpr: F) -> RecExpr<Self>
    where
        Self: 'a,
        F: FnMut(Id) -> &'a [Self],
    {
        // Copies the expression rooted at the last node of `from` into `to` and
        // returns the new id of that root.
        fn build<L: Language>(to: &mut RecExpr<L>, from: &[L]) -> Id {
            let last = from.last().unwrap().clone();
            let new_node = last.map_children(|id| {
                // A child with id `i` is the root of the prefix `from[..=i]`.
                let i = usize::from(id) + 1;
                build(to, &from[0..i])
            });
            to.add(new_node)
        }
        let mut expr = RecExpr::default();
        let node = self
            .clone()
            .map_children(|id| build(&mut expr, child_recexpr(id)));
        // `self`'s copy is added last, so it becomes the root of `expr`.
        expr.add(node);
        expr
    }
}
/// A container of child [`Id`]s for a language node: a fixed-size array, a boxed
/// slice, a `Vec`, or a single `Id` (see the impls in this file).
pub trait LanguageChildren {
    /// Number of child ids stored.
    fn len(&self) -> usize;

    /// `true` when no child ids are stored; defaults to checking `len`.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }

    /// Whether this container type is able to hold exactly `n` children.
    fn can_be_length(n: usize) -> bool;

    /// Builds the container from a vector of child ids.
    fn from_vec(v: Vec<Id>) -> Self;

    /// Borrows the child ids as a slice.
    fn as_slice(&self) -> &[Id];

    /// Mutably borrows the child ids as a slice.
    fn as_mut_slice(&mut self) -> &mut [Id];
}
macro_rules! impl_array {
() => {};
($n:literal, $($rest:tt)*) => {
impl LanguageChildren for [Id; $n] {
fn len(&self) -> usize { <[Id]>::len(self) }
fn can_be_length(n: usize) -> bool { n == $n }
fn from_vec(v: Vec<Id>) -> Self { Self::try_from(v.as_slice()).unwrap() }
fn as_slice(&self) -> &[Id] { self }
fn as_mut_slice(&mut self) -> &mut [Id] { self }
}
impl_array!($($rest)*);
};
}
impl_array!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,);
/// Variable-arity children stored in a boxed slice; any length is accepted.
#[rustfmt::skip]
impl LanguageChildren for Box<[Id]> {
    fn len(&self) -> usize { (**self).len() }
    fn can_be_length(_len: usize) -> bool { true }
    fn from_vec(v: Vec<Id>) -> Self { v.into_boxed_slice() }
    fn as_slice(&self) -> &[Id] { &**self }
    fn as_mut_slice(&mut self) -> &mut [Id] { &mut **self }
}
/// Variable-arity children stored in a `Vec`; any length is accepted.
#[rustfmt::skip]
impl LanguageChildren for Vec<Id> {
    fn len(&self) -> usize { Vec::len(self) }
    fn can_be_length(_len: usize) -> bool { true }
    fn from_vec(v: Vec<Id>) -> Self { v }
    fn as_slice(&self) -> &[Id] { &self[..] }
    fn as_mut_slice(&mut self) -> &mut [Id] { &mut self[..] }
}
/// A single child stored directly as an [`Id`].
#[rustfmt::skip]
impl LanguageChildren for Id {
    fn len(&self) -> usize { 1 }
    fn can_be_length(n: usize) -> bool { n == 1 }
    /// Builds the child from a one-element vector.
    ///
    /// # Panics
    /// Panics unless `v` holds exactly one id. This matches the fixed-size array
    /// impls, which also panic on a length mismatch, instead of silently discarding
    /// extra children.
    fn from_vec(v: Vec<Id>) -> Self {
        assert_eq!(v.len(), 1, "expected exactly 1 child id, got {}", v.len());
        v[0]
    }
    fn as_slice(&self) -> &[Id] { std::slice::from_ref(self) }
    fn as_mut_slice(&mut self) -> &mut [Id] { std::slice::from_mut(self) }
}
/// A flat expression: a list of nodes whose children are ids indexing into that
/// same list. When printed, the last node is treated as the root.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct RecExpr<L> {
    nodes: Vec<L>,
}
/// Serializes a `RecExpr` as the string of its s-expression form, rooted at the
/// last node. An empty expression serializes as `"()"`, the same text `Display`
/// produces.
#[cfg(feature = "serde-1")]
impl<L: Language> serde::Serialize for RecExpr<L> {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        // `len() - 1` would underflow (and panic) on an empty expression,
        // e.g. one built with `RecExpr::default()`.
        let s = match self.nodes.len().checked_sub(1) {
            Some(root) => self.to_sexp(root).to_string(),
            None => "()".to_string(),
        };
        serializer.serialize_str(&s)
    }
}
/// An expression with no nodes.
impl<L> Default for RecExpr<L> {
    fn default() -> Self {
        Self { nodes: Vec::new() }
    }
}
/// Borrows the expression's underlying node list.
impl<L> AsRef<[L]> for RecExpr<L> {
    fn as_ref(&self) -> &[L] {
        self.nodes.as_slice()
    }
}
impl<L> From<Vec<L>> for RecExpr<L> {
fn from(nodes: Vec<L>) -> Self {
Self { nodes }
}
}
impl<L: Language> RecExpr<L> {
    /// Appends `node` and returns its [`Id`], which is its index in this expression.
    ///
    /// In debug builds this panics if any child of `node` is not an index of a node
    /// that is already in the expression.
    pub fn add(&mut self, node: L) -> Id {
        let next_index = self.nodes.len();
        debug_assert!(
            node.children()
                .iter()
                .all(|&child| usize::from(child) < next_index),
            "node {:?} has children not in this expr: {:?}",
            node,
            self
        );
        self.nodes.push(node);
        Id::from(next_index)
    }
}
/// Prints the expression as an s-expression rooted at its last node, or `()` when
/// there are no nodes.
impl<L: Language> Display for RecExpr<L> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.nodes.len().checked_sub(1) {
            None => f.write_str("()"),
            Some(root) => {
                let rendered = self.to_sexp(root).to_string();
                f.write_str(&rendered)
            }
        }
    }
}
impl<L: Language> RecExpr<L> {
    /// Converts the subexpression rooted at node index `i` into an s-expression.
    ///
    /// A leaf becomes a bare `Sexp::String` of its `display_op`. Any other node becomes
    /// a `Sexp::List` of its operator followed by each child's s-expression. The
    /// recursion follows every child edge, so shared subterms are rendered repeatedly.
    fn to_sexp(&self, i: usize) -> Sexp {
        let node = &self.nodes[i];
        let op = Sexp::String(node.display_op().to_string());
        if node.is_leaf() {
            op
        } else {
            let mut vec = vec![op];
            node.for_each(|id| vec.push(self.to_sexp(id.into())));
            Sexp::List(vec)
        }
    }

    /// Pretty-prints the expression as an s-expression.
    ///
    /// A list whose one-line rendering is longer than `width` characters has every
    /// element after its head placed on a new line, indented by one space per nesting
    /// level. Shorter lists stay on one line. Leading and trailing `"` characters are
    /// stripped from atoms. An empty expression renders as `()`, matching `Display`.
    pub fn pretty(&self, width: usize) -> String {
        use std::fmt::{Result, Write};

        // Recursive printer; `level` is the indentation depth used when breaking
        // lines inside `sexp`.
        fn pp(buf: &mut String, sexp: &Sexp, width: usize, level: usize) -> Result {
            if let Sexp::List(list) = sexp {
                let indent = sexp.to_string().len() > width;
                write!(buf, "(")?;
                for (i, val) in list.iter().enumerate() {
                    if indent && i > 0 {
                        writeln!(buf)?;
                        for _ in 0..level {
                            write!(buf, " ")?;
                        }
                    }
                    pp(buf, val, width, level + 1)?;
                    if !indent && i < list.len() - 1 {
                        write!(buf, " ")?;
                    }
                }
                write!(buf, ")")?;
                Ok(())
            } else {
                write!(buf, "{}", sexp.to_string().trim_matches('"'))
            }
        }

        // The root is the last node. An empty expression has no root, and
        // `len() - 1` would underflow, so print it the way `Display` does.
        let root = match self.nodes.len().checked_sub(1) {
            Some(root) => root,
            None => return String::from("()"),
        };
        let sexp = self.to_sexp(root);
        let mut buf = String::new();
        pp(&mut buf, &sexp, width, 1).expect("writing to a String cannot fail");
        buf
    }
}
/// Early-returns `Err(...)` from the enclosing function.
///
/// A lone string literal (optionally followed by a comma) is converted with `From`
/// as-is and is not formatted. A literal followed by arguments is first passed
/// through `format!`.
macro_rules! bail {
    ($msg:literal $(,)?) => {
        return ::std::result::Result::Err(::std::convert::From::from($msg))
    };
    ($fmt:literal, $($arg:expr),+) => {
        return ::std::result::Result::Err(::std::convert::From::from(
            ::std::format!($fmt, $($arg),+),
        ))
    };
}
impl<L: Language> std::str::FromStr for RecExpr<L> {
    type Err = String;

    /// Parses an s-expression such as `(f a (g b))` into a `RecExpr`.
    ///
    /// Children are added before their parents, so the outermost node ends up last.
    /// Operators are turned into nodes with `L::from_op_str`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        /// Adds `sexp` (children first) to `expr` and returns the id of its root.
        fn parse_sexp_into<L: Language>(sexp: &Sexp, expr: &mut RecExpr<L>) -> Result<Id, String> {
            let items = match sexp {
                Sexp::Empty => return Err("Found empty s-expression".into()),
                Sexp::String(leaf) => {
                    let node = L::from_op_str(leaf, vec![])?;
                    return Ok(expr.add(node));
                }
                Sexp::List(items) => items,
            };
            let (head, args) = match items.split_first() {
                Some(parts) => parts,
                None => return Err("Found empty s-expression".into()),
            };
            let op = match head {
                Sexp::String(op) => op,
                Sexp::List(l) => bail!("Found a list in the head position: {:?}", l),
                Sexp::Empty => unreachable!("Cannot be in head position"),
            };
            // Parse arguments left to right, stopping at the first failure.
            let mut child_ids = Vec::with_capacity(args.len());
            for arg in args {
                child_ids.push(parse_sexp_into(arg, expr)?);
            }
            let node = L::from_op_str(op, child_ids)
                .map_err(|e| format!("Failed to parse '{}', error message:\n{}", sexp, e))?;
            Ok(expr.add(node))
        }

        let sexp = symbolic_expressions::parser::parse_str(s.trim()).map_err(|e| e.to_string())?;
        let mut expr = RecExpr::default();
        parse_sexp_into(&sexp, &mut expr)?;
        Ok(expr)
    }
}
/// Data computed for the nodes of an [`EGraph`] over language `L`.
pub trait Analysis<L: Language>: Sized {
    /// The data type this analysis produces.
    type Data: Debug;

    /// Computes the data for `enode`, given read access to `egraph`.
    fn make(egraph: &EGraph<L, Self>, enode: &L) -> Self::Data;

    /// Merges `from` into `to`. The returned `bool` presumably reports whether `to`
    /// changed, as `merge_if_different` does (confirm against `EGraph`).
    fn merge(&self, to: &mut Self::Data, from: Self::Data) -> bool;

    /// Hook that may mutate `egraph` for the class `id`; the default does nothing.
    fn modify(egraph: &mut EGraph<L, Self>, id: Id) {
        let _ = (egraph, id);
    }
}
/// Overwrites `*to` with `new` unless they are already equal.
///
/// Returns `true` if `*to` was replaced, and `false` if it was left untouched.
pub fn merge_if_different<D: PartialEq>(to: &mut D, new: D) -> bool {
    if *to == new {
        return false;
    }
    *to = new;
    true
}
/// The trivial analysis: it stores no data, and merging never reports a change.
impl<L: Language> Analysis<L> for () {
    type Data = ();

    fn make(_: &EGraph<L, Self>, _: &L) -> Self::Data {}

    fn merge(&self, _: &mut Self::Data, _: Self::Data) -> bool {
        false
    }
}
/// A generic language whose nodes are an operator [`Symbol`] plus any number of
/// children.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct SymbolLang {
    /// The operator of this node.
    pub op: Symbol,
    /// The child ids, in order.
    pub children: Vec<Id>,
}
impl SymbolLang {
    /// Creates a node with operator `op` and the given `children`.
    pub fn new(op: impl Into<Symbol>, children: Vec<Id>) -> Self {
        Self {
            op: op.into(),
            children,
        }
    }

    /// Creates a node with operator `op` and no children.
    pub fn leaf(op: impl Into<Symbol>) -> Self {
        Self::new(op, Vec::new())
    }
}
impl Language for SymbolLang {
    /// Two nodes match when they have the same number of children and the same `op`.
    fn matches(&self, other: &Self) -> bool {
        self.len() == other.len() && self.op == other.op
    }

    fn children(&self) -> &[Id] {
        self.children.as_slice()
    }

    fn children_mut(&mut self) -> &mut [Id] {
        self.children.as_mut_slice()
    }

    fn display_op(&self) -> &dyn Display {
        &self.op
    }

    /// Accepts any operator string with any number of children; never returns `Err`.
    fn from_op_str(op_str: &str, children: Vec<Id>) -> Result<Self, String> {
        Ok(Self::new(op_str, children))
    }
}