use crate::abstraction_learning::*;
use crate::abstraction_learning::egraphs::EGraph;
use lambdas::*;
use rustc_hash::{FxHashMap};
use std::hash::Hash;
/// Prints summary statistics for a set of programs to stdout: how many there are, and the
/// largest `cost()` and `depth()` among them.
///
/// An empty slice prints `n/a` for both maxima instead of panicking.
pub fn programs_info(programs: &[Expr]) {
    let max_cost = programs.iter().map(|p| p.cost()).max();
    let max_depth = programs.iter().map(|p| p.depth()).max();
    println!("Programs:");
    println!("\t num: {}", programs.len());
    match max_cost {
        Some(cost) => println!("\t max cost: {}", cost),
        None => println!("\t max cost: n/a"),
    }
    match max_depth {
        Some(depth) => println!("\t max depth: {}", depth),
        None => println!("\t max depth: n/a"),
    }
}
/// Returns the current local time as `YYYY-MM-DD_HH-MM-SS`.
///
/// The format contains no `:` or `/`, so the result can be embedded directly in file or
/// directory names.
pub fn timestamp() -> String {
    let now = chrono::Local::now();
    now.format("%Y-%m-%d_%H-%M-%S").to_string()
}
/// Writes `egraph.dot()` out as a PNG image at `{outdir}/{name}.png`.
///
/// # Panics
/// If producing or writing the image fails. The panic message names the target file and
/// includes the underlying error.
pub fn save(egraph: &EGraph, name: &str, outdir: &str) {
    egraph
        .dot()
        .to_png(format!("{}/{}.png", outdir, name))
        .unwrap_or_else(|err| {
            panic!("failed to save e-graph image {}/{}.png: {:?}", outdir, name, err)
        });
}
/// Returns a one-line size summary of `egraph`: total node count, number of classes, and
/// the value of `total_size()` (labelled "memo").
pub fn egraph_info(egraph: &EGraph) -> String {
    let nodes = egraph.total_number_of_nodes();
    let classes = egraph.number_of_classes();
    let memo = egraph.total_size();
    format!("{} nodes, {} classes, {} memo", nodes, classes, memo)
}
/// Ratio of the original expression's cost to the compressed expression's cost.
///
/// Values above 1.0 mean `compressed` is cheaper. A zero-cost `compressed` yields an
/// infinite or NaN result from the float division.
pub fn compression_factor(original: &Expr, compressed: &Expr) -> f64 {
    let before = f64::from(original.cost());
    let after = f64::from(compressed.cost());
    before / after
}
/// Rebuilds the subtree of `e` rooted at `child`, replacing each `IVar(i)` that has an entry
/// in `map` with a clone of that entry.
///
/// An `IVar` with no entry in `map` is kept as the same `IVar`. `Var`, `Prim`, `App` and
/// `Lam` nodes are copied structurally.
///
/// # Panics
/// If the subtree contains a `Programs` node.
pub fn ivar_replace(e: &Expr, child: Id, map: &FxHashMap<i32, Expr>) -> Expr {
    match e.get(child) {
        // Falling back to `e` here would splice the *entire* original expression into the
        // position of an unmapped ivar; keep the ivar itself instead.
        Lambda::IVar(i) => map.get(i).cloned().unwrap_or_else(|| Expr::ivar(*i)),
        Lambda::Var(v) => Expr::var(*v),
        Lambda::Prim(p) => Expr::prim(*p),
        Lambda::App([f, x]) => Expr::app(ivar_replace(e, *f, map), ivar_replace(e, *x, map)),
        Lambda::Lam([b]) => Expr::lam(ivar_replace(e, *b, map)),
        Lambda::Programs(_) => panic!("why would you do this"),
    }
}
/// Copies the subtree of `e` rooted at `child`, turning every `IVar` into a de Bruijn `Var`.
/// The result is intended as the body of `arity` enclosing lambdas.
///
/// `depth` is the number of lambdas crossed so far inside the body (start at 0). `IVar(i)`
/// becomes `Var(depth + arity - 1 - i)`. With indices counting from the innermost binder,
/// `IVar(0)` therefore refers to the outermost wrapping lambda and `IVar(arity - 1)` to the
/// innermost one. Existing `Var`s and `Prim`s are copied unchanged.
///
/// # Panics
/// If the subtree contains a `Programs` node.
pub fn ivar_to_dc(e: &Expr, child: Id, depth: i32, arity: i32) -> Expr {
    // Recurse into a child at a given lambda depth; `arity` is fixed for the whole walk.
    let convert = |node: Id, at_depth: i32| ivar_to_dc(e, node, at_depth, arity);
    match e.get(child) {
        Lambda::IVar(i) => {
            let offset = arity - 1 - i;
            Expr::var(depth + offset)
        }
        Lambda::Var(v) => Expr::var(*v),
        Lambda::Prim(p) => Expr::prim(*p),
        Lambda::App([f, x]) => Expr::app(convert(*f, depth), convert(*x, depth)),
        Lambda::Lam([b]) => Expr::lam(convert(*b, depth + 1)),
        Lambda::Programs(_) => panic!("why would you do this"),
    }
}
/// Renders an invention as a DreamCoder-style string.
///
/// The steps are:
/// 1. Convert the invention's ivars to de Bruijn vars.
/// 2. Wrap the result in `inv.arity` lambdas.
/// 3. Prefix the whole thing with `#`.
/// 4. Spell `lam` as `lambda`.
/// 5. Substitute each `(name, translation)` pair from `dreamcoder_translations`, in order,
///    for the matching primitive tokens.
pub fn dc_inv_str(inv: &Invention, dreamcoder_translations: &[(String, String)]) -> String {
    let arity = inv.arity;
    let dc_body = ivar_to_dc(&inv.body, inv.body.root(), 0, arity as i32);
    let wrapped = (0..arity).fold(dc_body, |acc, _| Expr::lam(acc));
    let lambda_form = format!("#{}", wrapped).replace("(lam ", "(lambda ");
    dreamcoder_translations
        .iter()
        .fold(lambda_form, |acc, (inv_name, dc_translation)| {
            replace_prim_with(&acc, inv_name, dc_translation)
        })
}
/// Replaces occurrences of the token `prim` in the s-expression string `s` with `new`.
///
/// A token is only matched in these positions, assuming `prim` itself contains no spaces
/// or parentheses:
/// * ` prim)` — last element of a list,
/// * ` prim ` — a middle element,
/// * `(prim ` — head of a list,
/// * `prim ` at the very start of `s`,
/// * ` prim` at the very end of `s`,
/// * `s == prim`.
///
/// Other placements, such as `(prim)`, are left untouched. Substrings of longer tokens
/// (for example `prim` inside `xprim` or `primx`) are never replaced.
///
/// # Panics
/// If an occurrence of ` prim ` survives both replacement passes. This can happen when `new`
/// itself contains ` prim `.
pub fn replace_prim_with(s: &str, prim: &str, new: &str) -> String {
    let mut res: String = s.replace(&format!(" {})", prim), &format!(" {})", new));

    // `str::replace` only finds non-overlapping matches. In " prim prim " the space shared by
    // the two occurrences is consumed by the first match, so the second one survives.
    // Survivors can never be adjacent to each other, so a second pass replaces all of them.
    let spaced_prim = format!(" {} ", prim);
    let spaced_new = format!(" {} ", new);
    res = res.replace(&spaced_prim, &spaced_new);
    res = res.replace(&spaced_prim, &spaced_new);
    assert!(
        !res.contains(&spaced_prim),
        "unreplaced occurrence of {:?} in {:?}",
        prim,
        res
    );

    res = res.replace(&format!("({} ", prim), &format!("({} ", new));

    // Leading token: the kept remainder already starts with its separating space.
    if res.starts_with(&format!("{} ", prim)) {
        res = format!("{}{}", new, &res[prim.len()..]);
    }
    // Trailing token: the kept prefix already ends with its separating space.
    if res.ends_with(&format!(" {}", prim)) {
        res = format!("{}{}", &res[..res.len() - prim.len()], new);
    }
    if res == prim {
        res = new.to_string();
    }
    res
}
/// Memo table used by [`recursive_var_mod`], keyed on `(eclass, depth)`.
///
/// A `None` value records that rewriting that eclass at that depth failed, or is still in
/// progress further up the recursion.
pub type RecVarModCache = FxHashMap<(Id, i32), Option<Id>>;

/// Rewrites the free variables of `eclass` with `var_mod`, starting at binder depth 0.
///
/// When `ivars` is true the free `IVar`s are rewritten; otherwise the free `Var`s are.
/// For each such leaf, `var_mod` is called as `var_mod(i, depth, i - depth, egraph)`. It
/// returns the replacement eclass, or `None` to make the whole rewrite fail. Results are
/// memoized in `seen`.
///
/// Returns the rebuilt eclass, or `None` if the rewrite failed.
pub fn recursive_var_mod(
    var_mod: impl Fn(i32, i32, i32, &mut EGraph) -> Option<Id>,
    ivars: bool,
    eclass: Id,
    egraph: &mut EGraph,
    seen: &mut RecVarModCache,
) -> Option<Id> {
    // The traversal begins outside of any lambda crossed by the rewrite.
    let start_depth = 0;
    recursive_var_mod_helper(&var_mod, ivars, eclass, start_depth, egraph, seen)
}
/// Worker for `recursive_var_mod`: rebuilds `eclass` and applies `var_mod` to every relevant
/// free variable leaf. `eclass` sits under `depth` lambdas crossed so far in this traversal.
///
/// * `ivars == false`: acts on `Var(i)` leaves with `i >= depth`, i.e. vars not bound by one
///   of the `depth` lambdas crossed so far.
/// * `ivars == true`: acts on `IVar(i)` leaves.
///
/// `var_mod` receives `(i, depth, i - depth, egraph)` and returns the replacement eclass, or
/// `None` to abort. A `None` from any leaf propagates up through `App`/`Lam`, and the whole
/// call returns `None`. Otherwise it returns the id (passed through `egraph.find`) of the
/// rebuilt eclass. Subtrees with nothing to rewrite are returned unchanged.
///
/// Results are memoized in `seen`, keyed on `(found eclass, depth)`.
///
/// # Panics
/// If an eclass that still needs rewriting does not hold exactly one enode. Also if the
/// traversal reaches a `Prim` or `Programs` node, a `Var` while `ivars` is true, or an `IVar`
/// while `ivars` is false.
fn recursive_var_mod_helper(
var_mod: &impl Fn(i32, i32, i32, &mut EGraph) -> Option<Id>,
ivars: bool, eclass:Id,
depth: i32,
egraph: &mut EGraph,
seen : &mut RecVarModCache,
) -> Option<Id>
{
// Look up the representative id so the cache key doesn't depend on which id of the class
// the caller passed.
let eclass = egraph.find(eclass);
let key = (eclass,depth);
// Memo hit. This is either a finished result, or the `None` placeholder for a key that
// already failed or is still being processed further up the recursion.
if seen.contains_key(&key) {
return seen[&key];
}
// Nothing to rewrite. In `ivars` mode that means no free IVars. Otherwise it means every free
// Var index is `< depth`, i.e. bound by a lambda crossed during this traversal.
if (ivars && egraph[eclass].data.free_ivars.is_empty())
|| (!ivars && egraph[eclass].data.free_vars.iter().all(|i| *i < depth)) {
seen.insert(key, Some(eclass));
return Some(eclass)
}
// Placeholder written before recursing. If the same (eclass, depth) is reached again before
// this call finishes, the memo check above returns `None` instead of recursing forever.
// If this call fails, the `None` stays in place as the cached result.
seen.insert(key, None);
// Any eclass that still needs rewriting is expected to hold exactly one enode.
assert!(egraph[eclass].nodes.len() == 1);
let enode = egraph[eclass].nodes[0].clone();
let new_eclass = match enode {
Lambda::Var(i) => {
if ivars {
panic!("unreachable, Var doesnt have free IVars")
}
// The free-var check above guarantees `i >= depth`. With de Bruijn indices, `i - depth`
// is the same variable's index outside the `depth` crossed lambdas.
assert!(i >= depth); var_mod(i, depth, i-depth, egraph)
}
Lambda::IVar(i) => {
if !ivars {
panic!("unreachable, IVar doesnt have free Vars")
}
// `i - depth` is passed exactly as for Vars; what it means for an IVar is up to
// `var_mod` (NOTE(review): confirm intended semantics with callers).
var_mod(i, depth, i-depth, egraph)
}
Lambda::Prim(_) => {
panic!("unreachable, Prim never has free vars/ivars")
}
Lambda::App([f, x]) => {
// Both children are always visited, even if `f` fails, so `x`'s result is cached too.
// The new App node is only added when both succeed.
let fnew_opt = recursive_var_mod_helper(var_mod, ivars, f, depth, egraph, seen);
let xnew_opt = recursive_var_mod_helper(var_mod, ivars, x, depth, egraph, seen);
match (fnew_opt,xnew_opt) {
(Some(fnew),Some(xnew)) => Some(egraph.add(Lambda::App([fnew, xnew]))),
_ => None,
}
}
Lambda::Lam([b]) => {
// The body sits one binder deeper.
recursive_var_mod_helper(var_mod, ivars, b, depth+1, egraph, seen)
.map(|bnew| egraph.add(Lambda::Lam([bnew])))
}
Lambda::Programs(_) => {
panic!("attempted to shift a Programs node")
}
};
// On success, replace the `None` placeholder with the found id of the rebuilt class.
// On failure, the placeholder is left as the cached `None`.
if let Some(new_eclass) = new_eclass {
let new_eclass = egraph.find(new_eclass);
seen.insert(key, Some(new_eclass));
Some(new_eclass)
} else {
None
}
}
/// Splits `v` into runs of consecutive elements whose `key` values are equal, preserving
/// order.
///
/// Only *adjacent* elements are compared, so equal keys that are not contiguous end up in
/// separate groups; sort `v` by `key` first to group globally. An empty input yields an
/// empty `Vec` rather than panicking, and `key` is called exactly once per element.
///
/// Elements are moved rather than copied, and keys only need `PartialEq`. These bounds are
/// weaker than `T: Copy, U: Ord`, so existing callers are unaffected.
#[inline]
pub fn group_by_key<T, U: PartialEq>(v: Vec<T>, key: impl Fn(&T) -> U) -> Vec<Vec<T>> {
    let mut groups: Vec<Vec<T>> = Vec::new();
    let mut prev_key: Option<U> = None;
    for item in v {
        let item_key = key(&item);
        if prev_key.as_ref() == Some(&item_key) {
            // `prev_key` only becomes `Some` after the first group has been pushed.
            groups
                .last_mut()
                .expect("a previous key implies an existing group")
                .push(item);
        } else {
            groups.push(vec![item]);
        }
        prev_key = Some(item_key);
    }
    groups
}
/// Counts how many distinct paths from each root reach each node.
///
/// The walk follows the first enode (`nodes[0]`) of every eclass. There is no memoization,
/// so a node is visited once per path, and the run time is proportional to the total
/// number of paths.
///
/// Returns `(totals, per_root)`:
/// * `per_root[r][n]` is the number of paths from `roots[r]` to node `n`.
/// * `totals[n]` is the sum of those counts over all roots.
///
/// Counts are indexed by `usize::from(id)`, with vectors sized `treenodes.len()`. This
/// assumes every reachable id is below that length; otherwise indexing panics.
pub fn num_paths_to_node(roots: &[Id], treenodes: &[Id], egraph: &EGraph) -> (Vec<i32>, Vec<Vec<i32>>) {
    // Adds one to `node` for the path that just reached it, then continues into each child.
    fn visit(counts: &mut [i32], node: Id, egraph: &EGraph) {
        counts[usize::from(node)] += 1;
        for child in egraph[node].nodes[0].children() {
            visit(counts, *child, egraph);
        }
    }

    let num_nodes = treenodes.len();
    let mut totals: Vec<i32> = vec![0; num_nodes];
    let mut per_root: Vec<Vec<i32>> = Vec::with_capacity(roots.len());
    for root in roots {
        let mut counts: Vec<i32> = vec![0; num_nodes];
        visit(&mut counts, *root, egraph);
        for (total, count) in totals.iter_mut().zip(&counts) {
            *total += *count;
        }
        per_root.push(counts);
    }
    (totals, per_root)
}
/// Counts how many times each distinct item occurs in `v`.
///
/// Despite the name, the returned map is an `FxHashMap` (from `rustc_hash`), not an `ahash`
/// map. Items are cloned into the map as keys.
pub fn counts_ahash<T: Hash + Eq + Clone>(v: &[T]) -> FxHashMap<T, usize> {
    let mut counts: FxHashMap<T, usize> = FxHashMap::default();
    for item in v {
        *counts.entry(item.clone()).or_insert(0) += 1;
    }
    counts
}