use crate::*;
use crate::parse_expr::{curry_sexp,uncurry_sexp};
use std::fmt::{self, Formatter, Display, Debug};
use std::hash::Hash;
use sexp::Sexp;
use serde::{Serialize, Deserialize};
use egg::Analysis;
/// One node of a lambda expression as stored in an [`Expr`] / egg `RecExpr`.
///
/// Interior nodes refer to their children by `Id` rather than owning them.
/// NOTE: variant order matters for the derived `PartialOrd`/`Ord` and for
/// index-based serde formats, so do not reorder the variants.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum Lambda {
    /// Variable, printed as `$i` (presumably a de Bruijn index — confirm).
    Var(i32),
    /// Variable printed as `#i`; its role is defined by callers outside this file.
    IVar(i32),
    /// Named primitive, printed as its symbol.
    Prim(Symbol),
    /// Application node: `[function, argument]` child ids.
    App([Id; 2]),
    /// Abstraction node: `[body]` child id.
    Lam([Id; 1]),
    /// Top-level node listing the root id of each program.
    Programs(Vec<Id>),
}
/// Child id value that marks a hole; printers in this file render it as `??`.
/// Equal to `u32::MAX`.
pub const SENTINEL: usize = 0xFFFF_FFFF;
/// A lambda expression stored as a flat list of [`Lambda`] nodes.
///
/// Nodes reference their children by position (`Id`) in `nodes`; the
/// constructors on `Expr` place the root last.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Expr {
    /// The node list, indexed by child `Id`s.
    pub nodes: Vec<Lambda>,
}
impl Display for Lambda {
    /// Prints only this node's operator, never its children:
    /// `$i`, `#i`, the primitive name, or `app` / `lam` / `programs`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match self {
            Self::Var(index) => write!(f, "${}", index),
            Self::IVar(index) => write!(f, "#{}", index),
            Self::Prim(name) => write!(f, "{}", name),
            Self::App(_) => f.write_str("app"),
            Self::Lam(_) => f.write_str("lam"),
            Self::Programs(_) => f.write_str("programs"),
        }
    }
}
impl Language for Lambda {
    /// True when `self` and `other` have the same operator, ignoring child ids.
    ///
    /// Leaves must carry equal payloads. `App` and `Lam` always match their
    /// own kind. `Programs` nodes match when they have the same arity.
    fn matches(&self, other: &Self) -> bool {
        match self {
            Self::Var(a) => matches!(other, Self::Var(b) if a == b),
            Self::IVar(a) => matches!(other, Self::IVar(b) if a == b),
            Self::Prim(a) => matches!(other, Self::Prim(b) if a == b),
            Self::App(_) => matches!(other, Self::App(_)),
            Self::Lam(_) => matches!(other, Self::Lam(_)),
            Self::Programs(a) => matches!(other, Self::Programs(b) if a.len() == b.len()),
        }
    }

    /// Child ids of this node (empty for leaves).
    fn children(&self) -> &[Id] {
        match self {
            Self::Var(_) | Self::IVar(_) | Self::Prim(_) => &[],
            Self::App(pair) => &pair[..],
            Self::Lam(body) => &body[..],
            Self::Programs(roots) => roots.as_slice(),
        }
    }

    /// Mutable child ids of this node (empty for leaves).
    fn children_mut(&mut self) -> &mut [Id] {
        match self {
            Self::Var(_) | Self::IVar(_) | Self::Prim(_) => &mut [],
            Self::App(pair) => &mut pair[..],
            Self::Lam(body) => &mut body[..],
            Self::Programs(roots) => roots.as_mut_slice(),
        }
    }
}
impl FromOp for Lambda {
    type Error = String;

    /// Builds a node from an operator token and its already-parsed children.
    ///
    /// `app` takes exactly 2 children, `lam` exactly 1, and `programs` any
    /// number. Any other token must have no children. It becomes `Var(i)` for
    /// `$i`, `IVar(i)` for `#i`, and a `Prim` otherwise.
    ///
    /// # Errors
    /// Returns `Err` when the child count is wrong for the operator, or when a
    /// `$`/`#` token is not followed by a valid `i32`. Malformed tokens used to
    /// panic here.
    fn from_op(op: &str, children: Vec<Id>) -> Result<Self, Self::Error> {
        // Parses the index after a `$`/`#` prefix and reports failures as errors.
        let parse_index = |digits: &str| {
            digits
                .parse::<i32>()
                .map_err(|e| format!("invalid variable index in {:?}: {}", op, e))
        };
        match op {
            "app" => {
                if children.len() != 2 {
                    return Err(format!("app needs 2 children, got {}", children.len()));
                }
                Ok(Self::App([children[0], children[1]]))
            }
            "lam" => {
                if children.len() != 1 {
                    return Err(format!("lam needs 1 child, got {}", children.len()));
                }
                Ok(Self::Lam([children[0]]))
            }
            "programs" => Ok(Self::Programs(children)),
            _ => {
                if !children.is_empty() {
                    return Err(format!("{} needs 0 children, got {}", op, children.len()));
                }
                if let Some(digits) = op.strip_prefix('$') {
                    parse_index(digits).map(Self::Var)
                } else if let Some(digits) = op.strip_prefix('#') {
                    parse_index(digits).map(Self::IVar)
                } else {
                    Ok(Self::Prim(egg::Symbol::from(op)))
                }
            }
        }
    }
}
impl From<RecExpr<Lambda>> for Expr {
    /// Takes ownership of a `RecExpr`'s node vector without copying it.
    fn from(e: RecExpr<Lambda>) -> Self {
        // SAFETY: assumes `RecExpr<Lambda>` has the same layout as its single
        // `Vec<Lambda>` field. NOTE(review): egg does not declare it
        // `repr(transparent)`, so this is not a language guarantee — confirm
        // against the egg version in use. `transmute` at least checks sizes.
        Expr::new(unsafe { std::mem::transmute::<RecExpr<Lambda>, Vec<Lambda>>(e) })
    }
}
impl From<Expr> for RecExpr<Lambda> {
fn from(e: Expr) -> Self {
unsafe{ std::mem::transmute(e.nodes) }
}
}
impl<'a> From<&'a Expr> for &'a RecExpr<Lambda> {
    /// Borrows an `Expr` as egg's `RecExpr` without copying its nodes.
    ///
    /// The output lifetime is explicitly tied to the input borrow. With elided
    /// lifetimes, an impl header declares two independent lifetimes. That let
    /// the reinterpreting cast hand out a `RecExpr` reference that outlived
    /// the `Expr` it points into.
    fn from(e: &'a Expr) -> Self {
        let nodes: *const Vec<Lambda> = &e.nodes;
        // SAFETY: assumes `RecExpr<Lambda>` is laid out exactly like its single
        // `Vec<Lambda>` field. NOTE(review): egg does not mark it
        // `repr(transparent)` — confirm against the egg version in use. The
        // pointer comes from a live `&'a Expr`, and the result is bounded by 'a.
        unsafe { &*(nodes as *const RecExpr<Lambda>) }
    }
}
impl Display for Expr {
    /// Pretty-prints the expression starting from its root.
    ///
    /// Applications are printed curried-flat: `app(app(f, x), y)` becomes
    /// `(f x y)`. Abstractions print as `(lam body)` and `programs` nodes as a
    /// parenthesised, space-separated list. `SENTINEL` holes print as `??`.
    /// An empty `programs` list prints as `()`; it used to panic.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        /// Prints the subtree at `child`. `left_of_app` is true when this node
        /// is the function position of an enclosing `App`, so that no extra
        /// parentheses are opened.
        fn fmt_local(e: &Expr, child: Id, left_of_app: bool, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            if usize::from(child) == SENTINEL {
                return write!(f, "??");
            }
            let node = &e.nodes[usize::from(child)];
            match node {
                Lambda::Var(_) | Lambda::IVar(_) | Lambda::Prim(_) => write!(f, "{}", node),
                Lambda::App([fun, x]) => {
                    if !left_of_app {
                        write!(f, "(")?;
                    }
                    fmt_local(e, *fun, true, f)?;
                    write!(f, " ")?;
                    fmt_local(e, *x, false, f)?;
                    if !left_of_app {
                        write!(f, ")")
                    } else {
                        Ok(())
                    }
                }
                Lambda::Lam([b]) => {
                    write!(f, "(lam ")?;
                    fmt_local(e, *b, false, f)?;
                    write!(f, ")")
                }
                Lambda::Programs(ids) => {
                    write!(f, "(")?;
                    for (i, id) in ids.iter().enumerate() {
                        // Separator before every element except the first;
                        // this also handles an empty list.
                        if i > 0 {
                            write!(f, " ")?;
                        }
                        fmt_local(e, *id, false, f)?;
                    }
                    write!(f, ")")
                }
            }
        }
        fmt_local(self, self.root(), false, f)
    }
}
impl Expr {
pub fn new(nodes: Vec<Lambda>) -> Self {
Self { nodes }
}
pub fn root(&self) -> Id {
Id::from(self.nodes.len()-1)
}
pub fn get_root(&self) -> &Lambda {
self.get(self.root())
}
pub fn get(&self, child:Id) -> &Lambda {
&self.nodes[usize::from(child)]
}
pub fn var(i: i32) -> Self {
Self::new(vec![Lambda::Var(i)])
}
pub fn ivar(i: i32) -> Self {
Self::new(vec![Lambda::IVar(i)])
}
pub fn prim(p: Symbol) -> Self {
Self::new(vec![Lambda::Prim(p)])
}
pub fn app(f: Expr, mut x: Expr) -> Self {
let mut nodes = f.nodes;
let f_id = Id::from(nodes.len()-1);
x.shift_nodes(nodes.len() as i32);
nodes.extend(x.nodes);
let x_id = Id::from(nodes.len()-1);
nodes.push(Lambda::App([f_id, x_id]));
Self::new(nodes)
}
pub fn lam(b: Expr) -> Self{
let mut nodes = b.nodes.clone();
let b_id = Id::from(b.nodes.len()-1);
nodes.push(Lambda::Lam([b_id]));
Self::new(nodes)
}
pub fn programs(programs: Vec<Expr>) -> Self {
let mut nodes = vec![];
let mut root_ids = vec![];
for mut p in programs.into_iter() {
p.shift_nodes(nodes.len() as i32);
nodes.extend(p.nodes);
root_ids.push(Id::from(nodes.len() - 1));
}
nodes.push(Lambda::Programs(root_ids));
Self::new(nodes)
}
pub fn split_programs(&self) -> Vec<Expr> {
match self.get_root() {
Lambda::Programs(roots) => {
let mut res: Vec<Expr> = vec![];
let mut start: usize = 0;
for root in roots.iter() {
let end = usize::from(*root)+1;
let mut e = Expr::new(self.nodes[start..end].to_vec());
e.shift_nodes(-(start as i32));
res.push(e);
start = end;
}
res
},
_ => unreachable!()
}
}
pub fn shift_nodes(&mut self, shift: i32) {
for node in &mut self.nodes {
node.update_children(|id| Id::from((usize::from(id) as i32 + shift) as usize));
}
}
pub fn depth(&self) -> i32 {
ProgramDepth{}.cost_rec(self.into())
}
pub fn cost(&self) -> i32 {
ProgramCost{}.cost_rec(self.into())
}
pub fn cloned_subexpr(&self, child:Id) -> Self {
assert!(self.nodes.len() > child.into());
Self::new(self.nodes.iter().take(usize::from(child)+1).cloned().collect())
}
pub fn into_subexpr(mut self, child:Id) -> Self {
assert!(self.nodes.len() > child.into());
self.nodes.truncate(usize::from(child)+1);
self
}
pub fn from_curried(s: &str) -> Result<Self,String> {
let recexpr: RecExpr<Lambda> = s.parse().map_err(|e|format!("{:?}",e))?;
Ok(recexpr.into())
}
pub fn from_uncurried(s: &str) -> Result<Self,String> {
let mut sexpr: Sexp = sexp::parse(s).map_err(|e|e.to_string())?;
sexpr = curry_sexp(&sexpr);
Self::from_curried(&sexpr.to_string())
}
pub fn to_string_curried(&self, child: Option<Id>) -> String {
let expr = match child {
None => self.clone(),
Some(id) => self.cloned_subexpr(id)
};
expr.to_sexp(self.root()).to_string()
}
pub fn to_string_uncurried(&self, child:Option<Id>) -> String {
uncurry_sexp(&self.to_sexp(child.unwrap_or_else(|| self.root()))).to_string()
}
pub fn to_sexp(&self, child: Id) -> Sexp {
if usize::from(child) == SENTINEL {
return Sexp::Atom(sexp::Atom::S("??".to_string()));
}
let node = &self.nodes[usize::from(child)];
match node {
Lambda::Var(_) | Lambda::IVar(_) | Lambda::Prim(_) => sexp::parse(&node.to_string()).unwrap(),
Lambda::App([f,x]) => {
let f = self.to_sexp(*f);
let x = self.to_sexp(*x);
let app = Sexp::Atom(sexp::Atom::S("app".to_string()));
Sexp::List(vec![app,f,x])
},
Lambda::Lam([b]) => {
let b = self.to_sexp(*b);
let lam = Sexp::Atom(sexp::Atom::S("lam".to_string()));
Sexp::List(vec![lam,b])
},
Lambda::Programs(root_ids) => {
let mut res = vec![Sexp::Atom(sexp::Atom::S("programs".to_string()))];
root_ids.iter().for_each(|id| res.push(self.to_sexp(*id)));
Sexp::List(res)
}
}
}
pub fn save<A: Analysis<Lambda> + Default>(&self, name: &str, outdir: &str) {
let mut egraph: EGraph<Lambda,A> = Default::default();
egraph.add_expr(self.into());
egraph.dot().to_png(format!("{}/{}.png",outdir,name)).unwrap();
}
}
/// Cost charged for each `app` / `lam` node by `ProgramCost`.
const COST_NONTERMINAL: i32 = 1;
/// Cost charged for each leaf (`Var`, `IVar`, `Prim`) by `ProgramCost`.
const COST_TERMINAL: i32 = 100;
/// Size-style cost function used by `Expr::cost`.
///
/// Each leaf costs `COST_TERMINAL`, and each `app`/`lam` adds
/// `COST_NONTERMINAL` on top of its children. A `programs` node costs the sum
/// of its programs.
pub struct ProgramCost {}

impl CostFunction<Lambda> for ProgramCost {
    type Cost = i32;

    fn cost<C>(&mut self, enode: &Lambda, mut costs: C) -> Self::Cost
    where
        C: FnMut(Id) -> Self::Cost,
    {
        match enode {
            // A `programs` node contributes nothing of its own.
            Lambda::Programs(ps) => ps.iter().fold(0, |total, &p| total + costs(p)),
            Lambda::Lam([b]) => COST_NONTERMINAL + costs(*b),
            Lambda::App([f, x]) => COST_NONTERMINAL + costs(*f) + costs(*x),
            Lambda::Var(_) | Lambda::IVar(_) | Lambda::Prim(_) => COST_TERMINAL,
        }
    }
}
/// Depth cost function used by `Expr::depth`.
///
/// A leaf has depth 1, and `app`/`lam` add 1 to their deepest child. A
/// `programs` node takes the maximum depth of its programs, or 0 when it has
/// none. The empty case used to panic; 0 matches `ProgramCost`, whose sum is
/// also 0 for an empty list.
pub struct ProgramDepth {}

impl CostFunction<Lambda> for ProgramDepth {
    type Cost = i32;

    fn cost<C>(&mut self, enode: &Lambda, mut costs: C) -> Self::Cost
    where
        C: FnMut(Id) -> Self::Cost,
    {
        match enode {
            Lambda::Var(_) | Lambda::IVar(_) | Lambda::Prim(_) => 1,
            Lambda::App([f, x]) => 1 + std::cmp::max(costs(*f), costs(*x)),
            Lambda::Lam([b]) => 1 + costs(*b),
            Lambda::Programs(ps) => ps.iter().map(|p| costs(*p)).max().unwrap_or(0),
        }
    }
}