use crate::abstraction_learning::*;
use crate::abstraction_learning::egraphs::EGraph;
use lambdas::*;
use rustc_hash::{FxHashMap,FxHashSet};
use std::fmt::{self, Formatter, Display};
use std::hash::Hash;
use itertools::Itertools;
use rewriting::extract;
use serde_json::json;
use clap::{Parser};
use serde::Serialize;
use std::thread;
use std::sync::Arc;
use parking_lot::Mutex;
use std::ops::DerefMut;
use std::collections::BinaryHeap;
use rand::Rng;
// Command-line configuration for one compression (abstraction-learning) step.
// NOTE: fields deliberately use plain `//` comments rather than `///` doc
// comments, because clap turns doc comments into --help text and we must not
// change the generated CLI help.
#[derive(Parser, Debug, Serialize, Clone)]
#[clap(name = "Stitch")]
pub struct CompressionStepConfig {
// maximum number of arguments (#0, #1, ...) an invention may take
#[clap(short='a', long, default_value = "2")]
pub max_arity: usize,
// number of worker threads running stitch_search()
#[clap(short='t', long, default_value = "1")]
pub threads: usize,
// disable all Stats counters (avoids lock traffic on the stats mutex)
#[clap(long)]
pub no_stats: bool,
// how many worklist patterns each thread grabs per lock acquisition
#[clap(short='b', long, default_value = "1")]
pub batch: usize,
// scale batch size with worklist length instead of using a fixed --batch
#[clap(long)]
pub dynamic_batch: bool,
// how many top candidate inventions to keep on the donelist
#[clap(short='n', long, default_value = "1")]
pub inv_candidates: usize,
// strategy for picking which hole to expand next (see HoleChoice)
#[clap(long, arg_enum, default_value = "depth-first")]
pub hole_choice: HoleChoice,
// disable the expected-cost vs rewritten-cost consistency check
#[clap(long)]
pub no_mismatch_check: bool,
// allow patterns whose root expands to a lambda
#[clap(long)]
pub no_top_lambda: bool,
// expression to track through the search (prints diagnostics when matched)
#[clap(long)]
pub track: Option<String>,
// prune every expansion that doesn't stay on the --track expression's path
#[clap(long)]
pub follow_track: bool,
// print each pattern as it is pulled off the worklist
#[clap(long)]
pub verbose_worklist: bool,
// print whenever a new best-utility invention is found
#[clap(long)]
pub verbose_best: bool,
// print stats every N worklist steps (0 = never)
#[clap(long, default_value = "0")]
pub print_stats: usize,
// show the rewritten programs at the end
#[clap(long,short='r')]
pub show_rewritten: bool,
// the no_opt_* flags below each disable one search-pruning optimization
#[clap(long)]
pub no_opt_free_vars: bool,
#[clap(long)]
pub no_opt_single_use: bool,
#[clap(long)]
pub no_opt_single_task: bool,
#[clap(long)]
pub no_opt_upper_bound: bool,
#[clap(long)]
pub no_opt_force_multiuse: bool,
#[clap(long)]
pub no_opt_useless_abstract: bool,
#[clap(long)]
pub no_opt_arity_zero: bool,
// zero out the noncompressive (structure-penalty) part of utility
#[clap(long)]
pub no_other_util: bool,
// sanity-check each finished pattern by actually rewriting with it
#[clap(long)]
pub rewrite_check: bool,
// compute compressive utility by rewriting instead of analytically
#[clap(long)]
pub utility_by_rewrite: bool,
// report timings/costs in a DreamCoder-comparable way
#[clap(long)]
pub dreamcoder_comparison: bool,
}
impl CompressionStepConfig {
    /// Disable every search-pruning optimization at once (equivalent to
    /// passing all of the `--no-opt-*` flags on the command line).
    pub fn no_opt(&mut self) {
        self.no_opt_free_vars = true;
        // BUG FIX: this flag was previously omitted, so single-use pruning
        // stayed enabled even when "all optimizations off" was requested.
        self.no_opt_single_use = true;
        self.no_opt_single_task = true;
        self.no_opt_upper_bound = true;
        self.no_opt_force_multiuse = true;
        self.no_opt_useless_abstract = true;
        self.no_opt_arity_zero = true;
    }
}
// A partially-built abstraction being grown by the search.
// - `holes`: zipper ids of the remaining unexpanded holes (??)
// - `arg_choices`: holes already turned into invention variables (#i)
// - `first_zid_of_ivar`: for each ivar #i, the zid of the first hole assigned to it
// - `match_locations`: treenodes where this pattern currently matches
// - `utility_upper_bound`: optimistic bound used for best-first ordering/pruning
// - `body_utility`: cost of the concrete (non-hole) part of the pattern body
// - `tracked`: whether this pattern lies on the --track expression's path
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Pattern {
pub holes: Vec<ZId>, arg_choices: Vec<LabelledZId>, pub first_zid_of_ivar: Vec<ZId>, pub match_locations: Vec<Id>, pub utility_upper_bound: i32,
pub body_utility: i32, pub tracked: bool, }
/// Returns a copy of `expr` in which the subtree reached by walking `zip`
/// from the root has been overwritten with a primitive named `new`.
/// Panics if the zipper does not apply to `expr`.
#[allow(clippy::ptr_arg)]
fn zipper_replace(expr: &Expr, zip: &Zip, new: &str) -> Expr {
    // Locate the node the zipper points at.
    let target = apply_zipper(expr, zip).unwrap();
    // Clone the whole expression, then splice the replacement prim in place.
    let mut replaced = expr.clone();
    replaced.nodes[usize::from(target)] = Lambda::Prim(new.into());
    replaced
}
/// Walks the zipper `zip` downward from the root of `expr`, returning the
/// `Id` of the node it lands on, or `None` if any step of the zipper does
/// not match the structure of the expression.
#[allow(clippy::ptr_arg)]
fn apply_zipper(expr: &Expr, zip: &Zip) -> Option<Id> {
    // Each ZNode must agree with the node it descends through; try_fold
    // short-circuits with None on the first mismatch.
    zip.iter().try_fold(expr.root(), |node, znode| {
        match (znode, expr.get(node)) {
            (ZNode::Body, Lambda::Lam([b])) => Some(*b),
            (ZNode::Func, Lambda::App([f, _])) => Some(*f),
            (ZNode::Arg, Lambda::App([_, x])) => Some(*x),
            _ => None,
        }
    })
}
// For a (tracked) expression containing ivars #0..#k, returns, for each ivar,
// the list of zipper ids (looked up in `zid_of_zip`) at which that ivar occurs.
// Used to recognize when a search expansion corresponds to the tracked expr.
fn zids_of_ivar_of_expr(expr: &Expr, zid_of_zip: &FxHashMap<Zip,ZId>) -> Vec<Vec<ZId>> {
// Arity = 1 + highest ivar index appearing anywhere in the expression.
let mut arity = 0;
for node in expr.nodes.iter() {
if let Lambda::IVar(ivar) = node {
if ivar + 1 > arity {
arity = ivar + 1;
}
}
}
let mut curr_zip: Zip = vec![];
let mut zids_of_ivar = vec![vec![]; arity as usize];
// Depth-first traversal that maintains the current zipper (path from root)
// and records the zid of the path whenever an ivar leaf is reached.
fn helper(curr_node: Id, expr: &Expr, curr_zip: &mut Zip, zids_of_ivar: &mut Vec<Vec<ZId>>, zid_of_zip: &FxHashMap<Zip,ZId>) {
match expr.get(curr_node) {
Lambda::Prim(_) => {},
Lambda::Var(_) => {},
Lambda::IVar(i) => {
// panics if this zipper was never registered — assumed by construction
zids_of_ivar[*i as usize].push(zid_of_zip[curr_zip]);
},
Lambda::Lam([b]) => {
curr_zip.push(ZNode::Body);
helper(*b, expr, curr_zip, zids_of_ivar, zid_of_zip);
curr_zip.pop();
}
Lambda::App([f,x]) => {
curr_zip.push(ZNode::Func);
helper(*f, expr, curr_zip, zids_of_ivar, zid_of_zip);
curr_zip.pop();
curr_zip.push(ZNode::Arg);
helper(*x, expr, curr_zip, zids_of_ivar, zid_of_zip);
curr_zip.pop();
}
_ => unreachable!(),
}
}
helper(expr.root(), expr, &mut curr_zip, &mut zids_of_ivar, zid_of_zip);
zids_of_ivar
}
impl Pattern {
// The initial pattern: a single hole at the empty zipper, matching every
// treenode (optionally excluding top-level lambdas). This is the root of
// the whole search.
fn single_hole(treenodes: &[Id], cost_of_node_all: &[i32], num_paths_to_node: &[i32], egraph: &EGraph, cfg: &CompressionStepConfig) -> Self {
let body_utility = 0;
let mut match_locations = treenodes.to_owned();
match_locations.sort(); if cfg.no_top_lambda {
match_locations.retain(|node| expands_to_of_node(&egraph[*node].nodes[0]) != ExpandsTo::Lam);
}
let utility_upper_bound = utility_upper_bound(&match_locations, body_utility, cost_of_node_all, num_paths_to_node, cfg);
Pattern {
holes: vec![EMPTY_ZID], arg_choices: vec![],
first_zid_of_ivar: vec![],
match_locations, utility_upper_bound,
body_utility, tracked: cfg.track.is_some(),
}
}
// Reconstructs this pattern as an Expr by walking the subtree at the first
// match location and substituting "??" at hole zippers and #ivar at
// argument zippers. Any match location would do; the first is arbitrary.
fn to_expr(&self, shared: &SharedData) -> Expr {
let mut curr_zip: Zip = vec![];
// Pair every hole/arg zipper with the expression it should become.
let zips: Vec<(Zip,Expr)> = self.holes.iter().map(|zid| (shared.zip_of_zid[*zid].clone(), Expr::prim("??".into())))
.chain(self.arg_choices.iter()
.map(|labelled_zid| (shared.zip_of_zid[labelled_zid.zid].clone(), Expr::ivar(labelled_zid.ivar as i32)))).collect();
// Recursive copy of the subtree, short-circuiting to the substitute
// expression whenever the current path matches a hole/arg zipper.
fn helper(curr_node: Id, curr_zip: &mut Zip, zips: &[(Zip,Expr)], shared: &SharedData) -> Expr {
match zips.iter().find(|(zip,_)| zip == curr_zip) {
Some((_,e)) => e.clone(),
None => {
match &shared.node_of_id[usize::from(curr_node)] {
Lambda::Prim(p) => Expr::prim(*p),
Lambda::Var(v) => Expr::var(*v),
Lambda::Lam([b]) => {
curr_zip.push(ZNode::Body);
let b_expr = helper(*b, curr_zip, zips, shared);
curr_zip.pop();
Expr::lam(b_expr)
}
Lambda::App([f,x]) => {
curr_zip.push(ZNode::Func);
let f_expr = helper(*f, curr_zip, zips, shared);
curr_zip.pop();
curr_zip.push(ZNode::Arg);
let x_expr = helper(*x, curr_zip, zips, shared);
curr_zip.pop();
Expr::app(f_expr, x_expr)
}
_ => unreachable!(),
}
}
}
}
helper(self.match_locations[0], &mut curr_zip, &zips, shared)
}
// Debug rendering used with --track: shows the pattern with the hole being
// expanded highlighted (in magenta) as the tracked expansion target.
fn show_track_expansion(&self, hole_zid: ZId, shared: &SharedData) -> String {
let mut s = zipper_replace(&self.to_expr(shared), &shared.zip_of_zid[hole_zid], "<REPLACE>" ).to_string();
s = s.replace(&"<REPLACE>", &format!("{}",tracked_expands_to(self, hole_zid, shared)).magenta().bold().to_string());
s
}
// One-line human-readable summary of the pattern and its bookkeeping.
pub fn info(&self, shared: &SharedData) -> String {
format!("{}: utility_upper_bound={}, body_utility={}, match_locations={}, usages={}",self.to_expr(shared), self.utility_upper_bound, self.body_utility, self.match_locations.len(), self.match_locations.iter().map(|loc|shared.num_paths_to_node[usize::from(*loc)]).sum::<i32>())
}
}
// What a hole can expand into at a given location: a lambda, an application,
// a de Bruijn variable $i, a primitive, or an invention variable #i.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Ord, PartialOrd)]
pub enum ExpandsTo {
Lam,
App,
Var(i32),
Prim(Symbol),
IVar(i32),
}
impl ExpandsTo {
    /// True iff expanding a hole into this production creates new holes —
    /// only `Lam` and `App` have children; leaves do not.
    #[inline]
    #[allow(dead_code)]
    fn has_holes(&self) -> bool {
        matches!(self, ExpandsTo::Lam | ExpandsTo::App)
    }
    /// True iff this expansion is an invention variable `#i`.
    #[inline]
    #[allow(dead_code)]
    fn is_ivar(&self) -> bool {
        matches!(self, ExpandsTo::IVar(_))
    }
}
impl std::fmt::Display for ExpandsTo {
    /// Renders an expansion the way it appears inside a partially-built
    /// pattern: children of Lam/App print as "??" holes.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match self {
            // fixed strings need no formatting machinery
            ExpandsTo::Lam => f.write_str("(lam ??)"),
            ExpandsTo::App => f.write_str("(?? ??)"),
            ExpandsTo::Var(v) => write!(f, "${}", v),
            ExpandsTo::Prim(p) => write!(f, "{}", p),
            ExpandsTo::IVar(v) => write!(f, "#{}", v),
        }
    }
}
// A zipper: the path (sequence of descent steps) from an expression's root
// down to one of its subtrees.
pub type Zip = Vec<ZNode>;
// Zid of the empty zipper (the root path); always registered first in get_zippers.
const EMPTY_ZID: ZId = 0;
// The subtree found at the end of a particular zipper from a particular node —
// i.e. what would become an invention argument if that zipper were abstracted.
// `shifted_id` is the subtree after downshifting free vars out of enclosing
// lambdas (by `shift`); `unshifted_id` is the original subtree.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Arg {
pub shifted_id: Id,
pub unshifted_id: Id, pub shift: i32,
pub cost: i32,
pub expands_to: ExpandsTo,
}
// Maps a concrete Lambda node to the ExpandsTo production a hole would take
// to match it. The "?#" primitive is a special marker that is not yet handled.
fn expands_to_of_node(node: &Lambda) -> ExpandsTo {
match node {
Lambda::Var(i) => ExpandsTo::Var(*i),
Lambda::Prim(p) => {
if *p == Symbol::from("?#") {
panic!("I still need to handle this") } else {
ExpandsTo::Prim(*p)
}
},
Lambda::Lam(_) => ExpandsTo::Lam,
Lambda::App(_) => ExpandsTo::App,
Lambda::IVar(i) => ExpandsTo::IVar(*i),
_ => unreachable!()
}
}
// For --track mode: what the tracked expression expands to at `hole_zid`.
// IVars need renumbering: the tracked expr's ivar indices may differ from the
// order in which the search assigns them, so we translate via the zippers at
// which each ivar occurs. Panics if tracking is not enabled or the zipper
// doesn't apply to the tracked expression.
fn tracked_expands_to(pattern: &Pattern, hole_zid: ZId, shared: &SharedData) -> ExpandsTo {
let id = apply_zipper(&shared.tracking.as_ref().unwrap().expr, &shared.zip_of_zid[hole_zid]).unwrap();
match expands_to_of_node(shared.tracking.as_ref().unwrap().expr.get(id)) {
ExpandsTo::IVar(i) => {
// If some existing pattern ivar #j already covers one of the tracked
// ivar's zippers, reuse #j; otherwise it's the next fresh ivar.
let zids = shared.tracking.as_ref().unwrap().zids_of_ivar[i as usize].clone();
for (j,zid) in pattern.first_zid_of_ivar.iter().enumerate() {
if zids.contains(zid) {
return ExpandsTo::IVar(j as i32);
}
}
ExpandsTo::IVar(pattern.first_zid_of_ivar.len() as i32)
}
e => e
}
}
// Worklist entry: a pattern keyed by its utility upper bound so the
// BinaryHeap pops the most promising pattern first.
#[derive(Debug,Clone, Eq, PartialEq)]
pub struct HeapItem {
key: i32,
pattern: Pattern,
}
impl PartialOrd for HeapItem {
    /// Defer to the `Ord` impl so the two orderings can never disagree
    /// (fixes clippy's `non_canonical_partial_ord_impl`; `i32` keys make
    /// the result identical to the previous `key.partial_cmp`).
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
// Orders items purely by `key` (the utility upper bound); combined with
// BinaryHeap this yields a max-heap, i.e. best-first search.
impl Ord for HeapItem {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.key.cmp(&other.key)
}
}
impl HeapItem {
// Wraps a pattern for the worklist, caching its upper bound as the heap key.
fn new(pattern: Pattern) -> Self {
HeapItem {
key: pattern.utility_upper_bound,
pattern
}
}
}
// All mutable search state shared between worker threads; always accessed
// under the single `SharedData::crit` mutex.
// - `donelist`: best finished inventions so far (sorted, truncated to cfg.inv_candidates)
// - `worklist`: pending patterns, max-heap on utility upper bound
// - `utility_pruning_cutoff`: utility of the worst kept donelist entry (0 if none)
// - `active_threads`: threads currently processing a batch (used to detect termination)
#[derive(Debug, Clone)]
pub struct CriticalMultithreadData {
donelist: Vec<FinishedPattern>,
worklist: BinaryHeap<HeapItem>,
utility_pruning_cutoff: i32,
active_threads: FxHashSet<std::thread::ThreadId>, }
// Read-mostly state shared (via Arc) by all search threads. Everything except
// `crit` and `stats` is immutable during the search.
#[derive(Debug)]
pub struct SharedData {
// mutable worklist/donelist state, guarded by a mutex
pub crit: Mutex<CriticalMultithreadData>,
// for each zid, a map from treenode to the Arg found at that zipper
pub arg_of_zid_node: Vec<FxHashMap<Id,Arg>>,
// all treenodes in bottom-up order
pub treenodes: Vec<Id>,
// node payload for each Id (indexable copy of the egraph contents)
pub node_of_id: Vec<Lambda>,
// the Programs node at the root of the corpus
pub programs_node: Id,
// root node of each program
pub roots: Vec<Id>,
// all zids applicable at each treenode
pub zids_of_node: FxHashMap<Id,Vec<ZId>>,
// zid <-> zipper lookup tables
pub zip_of_zid: Vec<Zip>,
pub zid_of_zip: FxHashMap<Zip, ZId>,
// for each zid, the zids obtained by appending Body/Arg/Func (if registered)
pub extensions_of_zid: Vec<ZIdExtension>,
pub egraph: EGraph,
// usage counts: how many root-to-node paths reach each node
pub num_paths_to_node: Vec<i32>,
pub num_paths_to_node_by_root_idx: Vec<Vec<i32>>,
// task bookkeeping (for the single-task pruning and min-cost-per-task rewriting)
pub tasks_of_node: Vec<FxHashSet<usize>>,
pub task_name_of_task: Vec<String>,
pub task_of_root_idx: Vec<usize>,
pub root_idxs_of_task: Vec<Vec<usize>>,
// subtree costs: once = cost of one occurrence, all = cost weighted by usages
pub cost_of_node_once: Vec<i32>,
pub cost_of_node_all: Vec<i32>,
// free de Bruijn variables of each node
pub free_vars_of_node: Vec<FxHashSet<i32>>,
// cost of the corpus before this compression step
pub init_cost: i32,
pub init_cost_by_root_idx: Vec<i32>,
pub stats: Mutex<Stats>,
pub cfg: CompressionStepConfig,
// present iff --track was given
pub tracking: Option<Tracking>,
}
// Parsed --track expression plus, for each of its ivars, the zippers at which
// that ivar occurs (used to translate ivar numbering during tracking).
#[derive(Debug)]
pub struct Tracking {
expr: Expr,
zids_of_ivar: Vec<Vec<ZId>>,
}
impl CriticalMultithreadData {
// Seeds the worklist with the initial single-hole pattern and normalizes
// the donelist/cutoff invariants via update().
fn new(donelist: Vec<FinishedPattern>, treenodes: &[Id], cost_of_node_all: &[i32], num_paths_to_node: &[i32], egraph: &EGraph, cfg: &CompressionStepConfig) -> Self {
let mut worklist = BinaryHeap::new();
worklist.push(HeapItem::new(Pattern::single_hole(treenodes, cost_of_node_all, num_paths_to_node, egraph, cfg)));
let mut res = CriticalMultithreadData {
donelist,
worklist,
utility_pruning_cutoff: 0,
active_threads: FxHashSet::default(),
};
res.update(cfg);
res
}
// Re-establish invariants: donelist sorted by descending utility (ties
// broken by arg_choices for determinism), truncated to the top
// `inv_candidates`, and the pruning cutoff set to the utility of the worst
// kept entry (or 0, and always 0 when upper-bound pruning is disabled).
fn update(&mut self, cfg: &CompressionStepConfig) {
self.donelist.sort_unstable_by(|a,b| (b.utility,&b.pattern.arg_choices).cmp(&(a.utility,&a.pattern.arg_choices)));
self.donelist.truncate(cfg.inv_candidates);
self.utility_pruning_cutoff = if cfg.no_opt_upper_bound { 0 } else { std::cmp::max(0,self.donelist.last().map(|x|x.utility).unwrap_or(0)) };
}
}
// A named, finished abstraction: `body` contains ivars #0..#arity-1.
#[derive(Debug, Clone)]
pub struct Invention {
pub body: Expr, pub arity: usize,
pub name: String,
}
impl Invention {
    /// Construct an invention from its body, arity, and display name.
    pub fn new(body: Expr, arity: usize, name: &str) -> Self {
        Self { body, arity, name: name.to_string() }
    }
    /// Instantiate the invention by substituting `args[i]` for ivar `#i`
    /// throughout the body. Panics if `args.len()` differs from the arity.
    pub fn apply(&self, args: &[Expr]) -> Expr {
        assert_eq!(args.len(), self.arity);
        let substitution: FxHashMap<i32, Expr> = args
            .iter()
            .cloned()
            .enumerate()
            .map(|(ivar, arg)| (ivar as i32, arg))
            .collect();
        ivar_replace(&self.body, self.body.root(), &substitution)
    }
}
// Renders as e.g. "[fn_0 arity=2: (lam (+ #0 #1))]".
impl Display for Invention {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(f, "[{} arity={}: {}]", self.name, self.arity, self.body)
}
}
// One step of a zipper: descend into an App's function, a Lam's body,
// or an App's argument.
#[derive(Debug, Clone, Eq, PartialEq, Hash, PartialOrd, Ord)]
pub enum ZNode {
Func, Body,
Arg, }
// Index into the global zipper table (`SharedData::zip_of_zid`).
pub type ZId = usize;
// A zipper id together with the invention variable (#ivar) assigned to the
// hole at that zipper.
#[derive(Debug,Clone, Eq, PartialEq, Hash, PartialOrd, Ord)]
struct LabelledZId {
zid: ZId,
ivar: usize }
// Search counters (guarded by SharedData::stats). The `*_fired` fields count
// how many times each pruning optimization rejected an expansion.
#[derive(Clone,Default, Debug)]
pub struct Stats {
worklist_steps: usize,
finished: usize,
calc_final_utility: usize,
upper_bound_fired: usize,
free_vars_fired: usize,
single_use_fired: usize,
single_task_fired: usize,
useless_abstract_fired: usize,
force_multiuse_fired: usize,
}
// Strategy for choosing which hole of a pattern to expand next; selected on
// the command line via --hole-choice. Not all variants are implemented yet
// (see HoleChoice::choose_hole).
#[derive(Debug, Clone, clap::ArgEnum, Serialize)]
pub enum HoleChoice {
Random,
BreadthFirst,
DepthFirst,
MaxLargestSubset,
HighEntropy,
LowEntropy,
MaxCost,
MinCost,
ManyGroups,
FewGroups,
FewApps,
}
impl HoleChoice {
// Picks the index (into pattern.holes) of the hole to expand next.
// With a single hole there is nothing to choose. HighEntropy/LowEntropy/
// ManyGroups/FewGroups are not implemented and will panic.
fn choose_hole(&self, pattern: &Pattern, shared: &SharedData) -> usize {
if pattern.holes.len() == 1 {
return 0;
}
match *self {
HoleChoice::BreadthFirst => 0,
HoleChoice::DepthFirst => pattern.holes.len() - 1,
HoleChoice::Random => {
let mut rng = rand::thread_rng();
rng.gen_range(0..pattern.holes.len())
},
// hole whose expansions contain the fewest App nodes across match locations
HoleChoice::FewApps => {
pattern.holes.iter().enumerate().map(|(hole_idx,hole_zid)|
(hole_idx, pattern.match_locations.iter().filter(|loc|shared.arg_of_zid_node[*hole_zid][loc].expands_to == ExpandsTo::App).count()))
.min_by_key(|x|x.1).unwrap().0
}
// hole with the largest / smallest total argument cost across match locations
HoleChoice::MaxCost => {
pattern.holes.iter().enumerate().map(|(hole_idx,hole_zid)|
(hole_idx, pattern.match_locations.iter().map(|loc|shared.arg_of_zid_node[*hole_zid][loc].cost).sum::<i32>()))
.max_by_key(|x|x.1).unwrap().0
}
HoleChoice::MinCost => {
pattern.holes.iter().enumerate().map(|(hole_idx,hole_zid)|
(hole_idx, pattern.match_locations.iter().map(|loc|shared.arg_of_zid_node[*hole_zid][loc].cost).sum::<i32>()))
.min_by_key(|x|x.1).unwrap().0
}
// hole whose most common expansion covers the most match locations,
// i.e. the one that keeps the largest subset together after expanding
HoleChoice::MaxLargestSubset => {
pattern.holes.iter().enumerate()
.map(|(hole_idx,hole_zid)| (hole_idx, *pattern.match_locations.iter()
.map(|loc| shared.arg_of_zid_node[*hole_zid][loc].expands_to.clone()).counts().values().max().unwrap())).max_by_key(|&(_,max_count)| max_count).unwrap().0
}
_ => unimplemented!()
}
}
}
impl LabelledZId {
fn new(zid: ZId, ivar: usize) -> LabelledZId {
LabelledZId { zid, ivar }
}
}
// For a given zipper, the zids of the zippers obtained by appending one more
// Body/Arg/Func step — None if that extended zipper never occurs in the corpus.
#[derive(Clone,Debug)]
pub struct ZIdExtension {
body: Option<ZId>,
arg: Option<ZId>,
func: Option<ZId>,
}
// Flushes a worker thread's local buffers into the shared worklist/donelist
// and pulls the next batch of patterns to process. Returns None when the
// search is finished (worklist empty and no thread still active). The lock
// is held for the whole call except while spin-waiting for work.
fn get_worklist_item(
worklist_buf: &mut Vec<HeapItem>,
donelist_buf: &mut Vec<FinishedPattern>,
shared: &Arc<SharedData>,
) -> Option<(Vec<Pattern>,i32)> {
let mut shared_guard = shared.crit.lock();
let mut crit: &mut CriticalMultithreadData = shared_guard.deref_mut();
let old_best_utility = crit.donelist.first().map(|x|x.utility).unwrap_or(0);
let old_donelist_len = crit.donelist.len();
let old_utility_pruning_cutoff = crit.utility_pruning_cutoff;
// Merge locally-finished patterns, dropping any already below the cutoff.
crit.donelist.extend(donelist_buf.drain(..).filter(|done| done.utility > old_utility_pruning_cutoff));
if !shared.cfg.no_stats { shared.stats.lock().deref_mut().finished += crit.donelist.len() - old_donelist_len; };
// Re-sort/truncate the donelist and refresh the pruning cutoff.
crit.update(&shared.cfg);
if shared.cfg.verbose_best && crit.donelist.first().map(|x|x.utility).unwrap_or(0) > old_best_utility {
println!("{} @ step={} util={} for {}", "[new best utility]".blue(), shared.stats.lock().deref_mut().worklist_steps, crit.donelist.first().unwrap().utility, crit.donelist.first().unwrap().info(shared));
}
let mut utility_pruning_cutoff = crit.utility_pruning_cutoff;
let old_worklist_len = crit.worklist.len();
let worklist_buf_len = worklist_buf.len();
// Merge locally-generated patterns, dropping those whose bound can't beat
// the cutoff; the difference in counts is how many were pruned here.
crit.worklist.extend(worklist_buf.drain(..).filter(|heap_item| heap_item.pattern.utility_upper_bound > utility_pruning_cutoff));
if !shared.cfg.no_stats { shared.stats.lock().deref_mut().upper_bound_fired += worklist_buf_len - (crit.worklist.len() - old_worklist_len); };
let mut returned_items = vec![];
// Mark ourselves idle until we actually take work (termination detection).
crit.active_threads.remove(&thread::current().id());
loop {
let batch_size = if shared.cfg.dynamic_batch { std::cmp::max(1, crit.worklist.len() / shared.cfg.threads ) } else { shared.cfg.batch };
while crit.worklist.is_empty() {
// Return a partial batch rather than waiting for more work.
if !returned_items.is_empty() {
crit.active_threads.insert(thread::current().id());
return Some((returned_items, utility_pruning_cutoff));
}
// No work anywhere and nobody producing more -> search is done.
if crit.active_threads.is_empty() {
return None }
// Another thread may still add work: release the lock briefly so it
// can, then re-acquire and re-check with a fresh cutoff.
drop(shared_guard);
shared_guard = shared.crit.lock();
crit = shared_guard.deref_mut();
utility_pruning_cutoff = crit.utility_pruning_cutoff;
}
let heap_item = crit.worklist.pop().unwrap();
// Items may have been pushed before the cutoff rose; re-check the bound.
if shared.cfg.no_opt_upper_bound || heap_item.pattern.utility_upper_bound > utility_pruning_cutoff {
returned_items.push(heap_item.pattern);
if returned_items.len() == batch_size {
crit.active_threads.insert(thread::current().id());
return Some((returned_items, utility_pruning_cutoff));
}
} else if !shared.cfg.no_stats { shared.stats.lock().deref_mut().upper_bound_fired += 1; }
}
}
// Main loop of one search worker thread. Repeatedly pulls a batch of patterns
// off the shared worklist, expands the chosen hole of each pattern in every
// possible way (grouped by what the hole expands to, plus ivar reuses/fresh
// ivars), applies the pruning optimizations, and pushes the results back as
// new worklist items or finished patterns. Returns when the search is done.
fn stitch_search(
shared: Arc<SharedData>,
) {
// Thread-local buffers, flushed into shared state by get_worklist_item,
// so we don't take the mutex on every single push.
let mut worklist_buf: Vec<HeapItem> = Default::default();
let mut donelist_buf: Vec<_> = Default::default();
loop {
let (patterns, mut weak_utility_pruning_cutoff) =
match get_worklist_item(
&mut worklist_buf,
&mut donelist_buf,
&shared,
) {
Some(pattern) => pattern,
None => return,
};
for original_pattern in patterns {
if !shared.cfg.no_stats { shared.stats.lock().deref_mut().worklist_steps += 1; };
if !shared.cfg.no_stats && shared.cfg.print_stats > 0 && shared.stats.lock().deref_mut().worklist_steps % shared.cfg.print_stats == 0 { println!("{:?} \n\t@ [bound={}; uses={}] chose: {}",shared.stats.lock().deref_mut(), original_pattern.utility_upper_bound, original_pattern.match_locations.iter().map(|loc| shared.num_paths_to_node[usize::from(*loc)]).sum::<i32>(), original_pattern.to_expr(&shared)); };
if shared.cfg.verbose_worklist {
println!("[bound={}; uses={}] chose: {}", original_pattern.utility_upper_bound, original_pattern.match_locations.iter().map(|loc| shared.num_paths_to_node[usize::from(*loc)]).sum::<i32>(), original_pattern.to_expr(&shared));
}
// Pick one hole to expand; the remaining holes carry over unchanged.
let hole_idx: usize = shared.cfg.hole_choice.choose_hole(&original_pattern, &shared);
let mut holes_after_pop: Vec<ZId> = original_pattern.holes.clone();
let hole_zid: ZId = holes_after_pop.remove(hole_idx);
// What sits under this hole at each match location.
let arg_of_loc = &shared.arg_of_zid_node[hole_zid];
// Sort locations by their expansion so group_by below groups correctly
// (group_by only merges adjacent equal keys).
let mut match_locations = original_pattern.match_locations.clone();
match_locations.sort_by_cached_key(|loc| (&arg_of_loc[loc].expands_to, *loc));
let ivars_expansions = get_ivars_expansions(&original_pattern, arg_of_loc, &shared);
let mut found_tracked = false;
// One iteration per possible expansion of the hole, each paired with
// the subset of match locations that remain valid under it.
'expansion:
for (expands_to, locs) in match_locations.into_iter()
.group_by(|loc| &arg_of_loc[loc].expands_to).into_iter()
.map(|(expands_to, locs)| (expands_to.clone(), locs.collect::<Vec<Id>>()))
.chain(ivars_expansions.into_iter())
{
let tracked = original_pattern.tracked && expands_to == tracked_expands_to(&original_pattern, hole_zid, &shared);
if tracked { found_tracked = true; }
if shared.cfg.follow_track && !tracked { continue 'expansion; }
// Opt: a pattern used once at a closed subtree can never beat arity-0
// rewriting of that same subtree, so it is not worth pursuing.
if !shared.cfg.no_opt_single_use && !shared.cfg.no_opt_arity_zero && locs.len() == 1 && shared.free_vars_of_node[usize::from(locs[0])].is_empty() {
if !shared.cfg.no_stats { shared.stats.lock().deref_mut().single_use_fired += 1; }
continue 'expansion;
}
// Opt: all locations within a single task -> no cross-task reuse.
if !shared.cfg.no_opt_single_task
&& locs.iter().all(|node| shared.tasks_of_node[usize::from(*node)].len() == 1)
&& locs.iter().all(|node| shared.tasks_of_node[usize::from(locs[0])].iter().next() == shared.tasks_of_node[usize::from(*node)].iter().next()) {
if !shared.cfg.no_stats { shared.stats.lock().deref_mut().single_task_fired += 1; }
if tracked { println!("{} single task pruned when expanding {} to {}", "[TRACK]".red().bold(), original_pattern.to_expr(&shared), zipper_replace(&original_pattern.to_expr(&shared), &shared.zip_of_zid[hole_zid], &format!("<{}>",expands_to))); }
continue 'expansion;
}
// Opt: a $i var pointing above the pattern body would be free in the
// invention, which is not allowed. (Depth = # of Body steps in zipper.)
if true { if let ExpandsTo::Var(i) = expands_to {
if i >= shared.zip_of_zid[hole_zid].iter().filter(|znode|**znode == ZNode::Body).count() as i32 {
if !shared.cfg.no_stats { shared.stats.lock().deref_mut().free_vars_fired += 1; };
if tracked { println!("{} pruned by free var in body when expanding {} to {}", "[TRACK]".red().bold(), original_pattern.to_expr(&shared), original_pattern.show_track_expansion(hole_zid, &shared)); }
continue 'expansion; }
}
}
// Opt: if some existing ivar takes the same argument at every location,
// that ivar could have been inlined — a strictly simpler pattern covers this.
if !shared.cfg.no_opt_useless_abstract {
for argchoice in original_pattern.arg_choices.iter(){
if locs.iter().map(|loc| shared.arg_of_zid_node[argchoice.zid][loc].shifted_id).all_equal()
{
if !shared.cfg.no_stats { shared.stats.lock().deref_mut().useless_abstract_fired += 1; };
continue 'expansion; }
}
}
// Growing the body costs one terminal/nonterminal; ivars are free here.
let body_utility = original_pattern.body_utility + match expands_to {
ExpandsTo::Lam | ExpandsTo::App => COST_NONTERMINAL,
ExpandsTo::Var(_) | ExpandsTo::Prim(_) => COST_TERMINAL,
ExpandsTo::IVar(_) => 0,
};
// Bounds only shrink as patterns grow; prune if we can't beat the cutoff.
let util_upper_bound: i32 = utility_upper_bound(&locs, body_utility, &shared.cost_of_node_all, &shared.num_paths_to_node, &shared.cfg);
assert!(util_upper_bound <= original_pattern.utility_upper_bound);
if !shared.cfg.no_opt_upper_bound && util_upper_bound <= weak_utility_pruning_cutoff {
if !shared.cfg.no_stats { shared.stats.lock().deref_mut().upper_bound_fired += 1; };
if tracked { println!("{} upper bound ({} < {}) pruned when expanding {} to {}", "[TRACK]".red().bold(), util_upper_bound, weak_utility_pruning_cutoff, original_pattern.to_expr(&shared), original_pattern.show_track_expansion(hole_zid, &shared)); }
continue 'expansion; }
// Expanding into Lam/App introduces new child holes at extended zippers.
let mut holes = holes_after_pop.clone();
match expands_to {
ExpandsTo::Lam => {
holes.push(shared.extensions_of_zid[hole_zid].body.unwrap());
}
ExpandsTo::App => {
holes.push(shared.extensions_of_zid[hole_zid].func.unwrap());
holes.push(shared.extensions_of_zid[hole_zid].arg.unwrap());
}
_ => {}
}
// Expanding into #i records the hole as an argument; a brand-new ivar
// also registers this zid as that ivar's first occurrence.
let mut arg_choices = original_pattern.arg_choices.clone();
let mut first_zid_of_ivar = original_pattern.first_zid_of_ivar.clone();
if let ExpandsTo::IVar(i) = expands_to {
arg_choices.push(LabelledZId::new(hole_zid, i as usize));
if i as usize == original_pattern.first_zid_of_ivar.len() {
first_zid_of_ivar.push(hole_zid);
}
}
// Opt: two ivars that take identical arguments everywhere are redundant —
// the variant that reuses one ivar dominates, so prune this one.
if !shared.cfg.no_opt_force_multiuse {
for (i,ivar_zid_1) in first_zid_of_ivar.iter().enumerate() {
let arg_of_loc_1 = &shared.arg_of_zid_node[*ivar_zid_1];
for ivar_zid_2 in first_zid_of_ivar.iter().skip(i+1) {
let arg_of_loc_2 = &shared.arg_of_zid_node[*ivar_zid_2];
if locs.iter().all(|loc|
arg_of_loc_1[loc].shifted_id == arg_of_loc_2[loc].shifted_id)
{
if !shared.cfg.no_stats { shared.stats.lock().deref_mut().force_multiuse_fired += 1; };
if tracked { println!("{} force multiuse pruned when expanding {} to {}", "[TRACK]".red().bold(), original_pattern.to_expr(&shared), original_pattern.show_track_expansion(hole_zid, &shared)); }
continue 'expansion;
}
}
}
}
let new_pattern = Pattern {
holes,
arg_choices,
first_zid_of_ivar,
match_locations: locs,
utility_upper_bound: util_upper_bound,
body_utility,
tracked
};
if new_pattern.holes.is_empty() {
// No holes left: compute real utility and buffer for the donelist.
let finished_pattern = FinishedPattern::new(new_pattern, &shared);
if !shared.cfg.no_stats { shared.stats.lock().calc_final_utility += 1; };
if shared.cfg.rewrite_check {
// Sanity check: rewriting must actually succeed at all locations.
rewrite_fast(&finished_pattern, &shared, "fake_inv");
}
if tracked {
println!("{} pushed {} to donelist (util: {})", "[TRACK:DONE]".green().bold(), finished_pattern.to_expr(&shared), finished_pattern.utility);
}
// With a single candidate slot we can raise our local cutoff eagerly.
if shared.cfg.inv_candidates == 1 && finished_pattern.utility > weak_utility_pruning_cutoff {
weak_utility_pruning_cutoff = finished_pattern.utility;
}
donelist_buf.push(finished_pattern);
} else {
if tracked { println!("{} pushed {} to work list (bound: {})", "[TRACK]".green().bold(), original_pattern.show_track_expansion(hole_zid, &shared), new_pattern.utility_upper_bound); }
worklist_buf.push(HeapItem::new(new_pattern))
}
}
if original_pattern.tracked && !found_tracked {
println!("{} pruned when expanding because there were no match locations for the target expansion of {} to {}", "[TRACK]".red().bold(), original_pattern.to_expr(&shared), original_pattern.show_track_expansion(hole_zid, &shared));
}
}
}
}
/// Computes the invention-variable expansions available for the current hole:
/// one entry per existing ivar that can be reused (at the locations where the
/// hole's argument equals that ivar's argument), plus — while under the max
/// arity — a fresh ivar that is valid at every match location.
fn get_ivars_expansions(original_pattern: &Pattern, arg_of_loc: &FxHashMap<Id,Arg>, shared: &Arc<SharedData>) -> Vec<(ExpandsTo, Vec<Id>)> {
    let num_ivars = original_pattern.first_zid_of_ivar.len();
    let mut expansions = Vec::with_capacity(num_ivars + 1);
    // Reusing #i is only possible where the argument at this hole is the same
    // subtree as the argument #i already captures.
    for (ivar, zid) in original_pattern.first_zid_of_ivar.iter().enumerate() {
        let arg_of_loc_ivar = &shared.arg_of_zid_node[*zid];
        let locs: Vec<Id> = original_pattern.match_locations.iter()
            .filter(|loc| arg_of_loc[loc].shifted_id == arg_of_loc_ivar[loc].shifted_id)
            .cloned()
            .collect();
        if !locs.is_empty() {
            expansions.push((ExpandsTo::IVar(ivar as i32), locs));
        }
    }
    // A brand-new ivar matches everywhere, as long as arity allows it.
    if num_ivars < shared.cfg.max_arity {
        expansions.push((ExpandsTo::IVar(num_ivars as i32), original_pattern.match_locations.clone()));
    }
    expansions
}
// A fully-expanded pattern (no holes left) with its final utilities computed:
// `utility` = compressive + noncompressive part; `usages` = total number of
// root-to-location paths it would rewrite.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FinishedPattern {
pub pattern: Pattern,
pub utility: i32,
pub compressive_utility: i32,
pub util_calc: UtilityCalculation,
pub arity: usize,
pub usages: i32,
}
impl FinishedPattern {
// Computes the final utilities for a hole-free pattern. Panics (assert) if
// the real utility ever exceeds the upper bound, which would indicate a
// bug in the bound computation.
fn new(pattern: Pattern, shared: &SharedData) -> Self {
let arity = pattern.first_zid_of_ivar.len();
let usages = pattern.match_locations.iter().map(|loc| shared.num_paths_to_node[usize::from(*loc)]).sum();
let compressive_utility = compressive_utility(&pattern,shared);
let noncompressive_utility = noncompressive_utility(pattern.body_utility, &shared.cfg);
let utility = noncompressive_utility + compressive_utility.util;
assert!(utility <= pattern.utility_upper_bound, "{} BUT utility is higher: {} (usages: {})", pattern.info(shared), utility, usages);
let mut res = FinishedPattern {
pattern,
utility,
compressive_utility: compressive_utility.util,
util_calc: compressive_utility,
arity,
usages,
};
// Optionally replace the analytic compressive utility with the exact cost
// delta obtained by actually rewriting (keeping only the cheapest program
// per task, matching how final cost is measured).
if shared.cfg.utility_by_rewrite {
let rewritten: Vec<Expr> = rewrite_fast(&res, shared, "fake_inv");
res.compressive_utility = shared.init_cost - shared.root_idxs_of_task.iter().map(|root_idxs|
root_idxs.iter().map(|idx| rewritten[*idx].cost()).min().unwrap()
).sum::<i32>();
res.util_calc.util = res.compressive_utility;
res.utility = res.compressive_utility + noncompressive_utility;
}
res
}
// Render the finished pattern body as an expression.
pub fn to_expr(&self, shared: &SharedData) -> Expr {
self.pattern.to_expr(shared)
}
// Package this pattern up as a named Invention.
pub fn to_invention(&self, name: &str, shared: &SharedData) -> Invention {
Invention::new(self.to_expr(shared), self.arity, name)
}
// One-line human-readable summary.
pub fn info(&self, shared: &SharedData) -> String {
format!("{} -> finished: utility={}, compressive_utility={}, arity={}, usages={}",self.pattern.info(shared), self.utility, self.compressive_utility, self.arity, self.usages)
}
}
// Builds the global zipper tables for the whole corpus. For every treenode,
// enumerates every zipper (downward path) from that node to each of its
// subtrees, interning zippers as zids, and records the Arg (the subtree at the
// end of the zipper) for each (zid, node) pair. Zippers through a lambda
// require downshifting the argument's free variables; a free $0 crossing the
// lambda is rebound as an ivar via insert_arg_ivars. Works bottom-up: each
// node's zippers are its children's zippers with one step prepended.
// Returns (zid_of_zip, zip_of_zid, arg_of_zid_node, zids_of_node,
// extensions_of_zid).
#[allow(clippy::type_complexity)]
fn get_zippers(
treenodes: &[Id],
cost_of_node_once: &[i32],
egraph: &mut EGraph,
) -> (FxHashMap<Zip, ZId>, Vec<Zip>, Vec<FxHashMap<Id,Arg>>, FxHashMap<Id,Vec<ZId>>, Vec<ZIdExtension>) {
let cache: &mut Option<RecVarModCache> = &mut Some(FxHashMap::default());
let mut zid_of_zip: FxHashMap<Zip, ZId> = Default::default();
let mut zip_of_zid: Vec<Zip> = Default::default();
let mut arg_of_zid_node: Vec<FxHashMap<Id,Arg>> = Default::default();
let mut zids_of_node: FxHashMap<Id,Vec<ZId>> = Default::default();
// The empty zipper (node matched at its own root) is always zid 0.
zid_of_zip.insert(vec![], EMPTY_ZID);
zip_of_zid.push(vec![]);
arg_of_zid_node.push(FxHashMap::default());
// treenodes are in bottom-up order, so children are always processed first.
for treenode in treenodes.iter() {
assert!(egraph[*treenode].nodes.len() == 1);
let node = egraph[*treenode].nodes[0].clone();
let mut zids: Vec<ZId> = vec![EMPTY_ZID];
// The empty zipper's Arg is the node itself, unshifted.
arg_of_zid_node[EMPTY_ZID].insert(*treenode,
Arg { shifted_id: *treenode, unshifted_id: *treenode, shift: 0, cost: cost_of_node_once[usize::from(*treenode)], expands_to: expands_to_of_node(&node) });
match node {
Lambda::IVar(_) => { panic!("attempted to abstract an IVar") }
Lambda::Var(_) | Lambda::Prim(_) | Lambda::Programs(_) => {},
Lambda::App([f,x]) => {
// Each zipper of the function child becomes Func-prefixed here,
// keeping the same Arg at the far end.
for f_zid in zids_of_node[&f].iter() {
let mut zip = zip_of_zid[*f_zid].clone();
zip.insert(0,ZNode::Func);
let zid = zid_of_zip.entry(zip.clone()).or_insert_with(|| {
let zid = zip_of_zid.len();
zip_of_zid.push(zip);
arg_of_zid_node.push(FxHashMap::default());
zid
});
zids.push(*zid);
let arg = arg_of_zid_node[*f_zid][&f].clone();
arg_of_zid_node[*zid].insert(*treenode, arg);
}
// Likewise for the argument child with an Arg prefix.
for x_zid in zids_of_node[&x].iter() {
let mut zip = zip_of_zid[*x_zid].clone();
zip.insert(0,ZNode::Arg);
let zid = zid_of_zip.entry(zip.clone()).or_insert_with(|| {
let zid = zip_of_zid.len();
zip_of_zid.push(zip);
arg_of_zid_node.push(FxHashMap::default());
zid
});
zids.push(*zid);
let arg = arg_of_zid_node[*x_zid][&x].clone();
arg_of_zid_node[*zid].insert(*treenode, arg);
}
},
Lambda::Lam([b]) => {
for b_zid in zids_of_node[&b].iter() {
let mut zip = zip_of_zid[*b_zid].clone();
zip.insert(0,ZNode::Body);
let zid = zid_of_zip.entry(zip.clone()).or_insert_with(|| {
let zid = zip_of_zid.len();
zip_of_zid.push(zip.clone());
arg_of_zid_node.push(FxHashMap::default());
zid
});
zids.push(*zid);
let mut arg: Arg = arg_of_zid_node[*b_zid][&b].clone();
// Crossing a lambda: the argument's free vars lose one binder.
if !egraph[arg.shifted_id].data.free_vars.is_empty() {
if egraph[arg.shifted_id].data.free_vars.contains(&0) {
// $0 refers to this lambda: rebind it as an ivar so the
// abstraction can still capture it as an argument.
let depth_root_to_arg = zip.iter().filter(|x| **x == ZNode::Body).count() as i32;
arg.shifted_id = insert_arg_ivars(arg.shifted_id, depth_root_to_arg-1, egraph).unwrap();
}
arg.shifted_id = shift(arg.shifted_id, -1, egraph, cache).unwrap();
arg.shift -= 1;
}
arg_of_zid_node[*zid].insert(*treenode, arg);
} },
}
zids_of_node.insert(*treenode, zids);
}
// Precompute, for every zipper, the zids of its one-step extensions (None
// when the extended zipper never occurs anywhere in the corpus).
let extensions_of_zid = zip_of_zid.iter().map(|zip| {
let mut zip_body = zip.clone();
zip_body.push(ZNode::Body);
let mut zip_arg = zip.clone();
zip_arg.push(ZNode::Arg);
let mut zip_func = zip.clone();
zip_func.push(ZNode::Func);
ZIdExtension {
body: zid_of_zip.get(&zip_body).copied(),
arg: zid_of_zip.get(&zip_arg).copied(),
func: zid_of_zip.get(&zip_func).copied(),
}
}).collect();
(zid_of_zip,
zip_of_zid,
arg_of_zid_node,
zids_of_node,
extensions_of_zid)
}
// Everything produced by one compression step: the chosen invention, the
// corpus rewritten with it (also in DreamCoder syntax), cost accounting
// (expected vs measured), and per-use details for JSON output.
#[derive(Debug, Clone)]
pub struct CompressionStepResult {
pub inv: Invention,
pub rewritten: Expr,
pub rewritten_dreamcoder: Vec<String>,
pub done: FinishedPattern,
pub expected_cost: i32,
pub final_cost: i32,
// compression ratio vs this step's input / vs the original corpus
pub multiplier: f64,
pub multiplier_wrt_orig: f64,
pub uses: i32,
pub use_exprs: Vec<Expr>,
pub use_args: Vec<Vec<Expr>>,
pub dc_inv_str: String,
pub initial_cost: i32,
}
impl CompressionStepResult {
// Materializes a finished pattern into a full step result: rewrites the
// corpus, measures the real cost (cheapest program per task), checks it
// against the analytically expected cost, collects per-use expressions and
// arguments, and builds the DreamCoder-syntax forms (translating previous
// inventions' names via `past_invs` / `prev_dc_inv_to_inv_strs`).
fn new(done: FinishedPattern, inv_name: &str, shared: &mut SharedData, past_invs: &[CompressionStepResult], prev_dc_inv_to_inv_strs: &[(String, String)]) -> Self {
// Cost of the corpus before the *first* step, for multiplier_wrt_orig.
let very_first_cost = if let Some(past_inv) = past_invs.first() { past_inv.initial_cost } else { shared.init_cost };
let inv = done.to_invention(inv_name, shared);
let rewritten = rewrite_fast(&done, shared, &inv.name);
let expected_cost = shared.init_cost - done.compressive_utility;
let final_cost = shared.root_idxs_of_task.iter().map(|root_idxs|
root_idxs.iter().map(|idx| rewritten[*idx].cost()).min().unwrap()
).sum::<i32>();
// A mismatch indicates the analytic utility disagreed with reality.
if expected_cost != final_cost {
println!("*** expected cost {} != final cost {}", expected_cost, final_cost);
}
let multiplier = shared.init_cost as f64 / final_cost as f64;
let multiplier_wrt_orig = very_first_cost as f64 / final_cost as f64;
let uses = done.usages;
let use_exprs: Vec<Expr> = done.pattern.match_locations.iter().map(|node| extract(*node, &shared.egraph)).collect();
let use_args: Vec<Vec<Expr>> = done.pattern.match_locations.iter().map(|node|
done.pattern.first_zid_of_ivar.iter().map(|zid|
extract(shared.arg_of_zid_node[*zid][node].shifted_id, &shared.egraph)
).collect()).collect();
// Name -> DreamCoder-string translations for all earlier inventions.
let mut dreamcoder_translations: Vec<(String, String)> = past_invs.iter().map(|compression_step_result| (compression_step_result.inv.name.clone(), compression_step_result.dc_inv_str.clone())).collect();
dreamcoder_translations.extend(prev_dc_inv_to_inv_strs.iter().cloned());
let dc_inv_str: String = dc_inv_str(&inv, &dreamcoder_translations);
// Rewrite each program into DreamCoder syntax: substitute invention
// names with their #-expression strings and use "lambda" keywords.
let rewritten_dreamcoder: Vec<String> = rewritten.iter().map(|p|{
let mut res = p.to_string();
for (prev_inv_name, prev_dc_inv_str) in prev_dc_inv_to_inv_strs {
res = replace_prim_with(&res, prev_inv_name, prev_dc_inv_str);
}
res = replace_prim_with(&res, inv_name, &dc_inv_str);
res = res.replace("(lam ","(lambda ");
res
}).collect();
CompressionStepResult { inv, rewritten: Expr::programs(rewritten), rewritten_dreamcoder, done, expected_cost, final_cost, multiplier, multiplier_wrt_orig, uses, use_exprs, use_args, dc_inv_str, initial_cost: shared.init_cost }
}
// Serializes the result (invention, rewritten programs, costs, uses) for
// the JSON output file.
pub fn json(&self) -> serde_json::Value {
let use_exprs: Vec<String> = self.use_exprs.iter().map(|expr| expr.to_string()).collect();
let use_args: Vec<String> = self.use_args.iter().map(|args| format!("{} {}", self.inv.name, args.iter().map(|expr| expr.to_string()).collect::<Vec<String>>().join(" "))).collect();
let all_uses: Vec<serde_json::Value> = use_exprs.iter().zip(use_args.iter()).sorted().map(|(expr,args)| json!({args: expr})).collect();
json!({
"body": self.inv.body.to_string(),
"dreamcoder": self.dc_inv_str,
"arity": self.inv.arity,
"name": self.inv.name,
"rewritten": self.rewritten.split_programs().iter().map(|p| p.to_string()).collect::<Vec<String>>(),
"rewritten_dreamcoder": self.rewritten_dreamcoder,
"utility": self.done.utility,
"expected_cost": self.expected_cost,
"final_cost": self.final_cost,
"multiplier": self.multiplier,
"multiplier_wrt_orig": self.multiplier_wrt_orig,
"num_uses": self.uses,
"uses": all_uses,
})
}
}
impl fmt::Display for CompressionStepResult {
    /// One-line human-readable summary of the invention and its impact.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Surface any disagreement between predicted and measured cost up front.
        let delta = self.expected_cost - self.final_cost;
        if delta != 0 {
            write!(f,"[cost mismatch of {}] ", delta)?;
        }
        write!(f, "utility: {} | final_cost: {} | {:.2}x | uses: {} | body: {}",
               self.done.utility, self.final_cost, self.multiplier, self.uses, self.inv)
    }
}
/// Upper bound on a pattern's total utility: the sum of the compressive
/// bound (savings from rewriting every match location) and the
/// noncompressive bound (structure-penalty terms).
fn utility_upper_bound(
    match_locations: &[Id],
    body_utility_lower_bound: i32,
    cost_of_node_all: &[i32],
    num_paths_to_node: &[i32],
    cfg: &CompressionStepConfig,
) -> i32 {
    let compressive = compressive_utility_upper_bound(match_locations, cost_of_node_all, num_paths_to_node);
    let noncompressive = noncompressive_utility_upper_bound(body_utility_lower_bound, cfg);
    compressive + noncompressive
}
/// Structure penalty for an invention: the cost of its own body counts
/// against it, unless non-compressive utility terms are disabled by flag.
fn noncompressive_utility(
    body_utility: i32,
    cfg: &CompressionStepConfig,
) -> i32 {
    if cfg.no_other_util { 0 } else { -body_utility }
}
/// Optimistic bound on compressive utility: assume every match location is
/// rewritten, each saving its full subtree cost minus one terminal per path
/// (the invention name that replaces it).
fn compressive_utility_upper_bound(
    match_locations: &[Id],
    cost_of_node_all: &[i32],
    num_paths_to_node: &[i32],
) -> i32 {
    match_locations.iter().fold(0i32, |acc, node| {
        let idx = usize::from(*node);
        acc + cost_of_node_all[idx] - num_paths_to_node[idx] * COST_TERMINAL
    })
}
/// Upper bound on the noncompressive (structure-penalty) utility terms,
/// using a lower bound on the eventual body utility. Zero when non-
/// compressive terms are disabled by flag.
fn noncompressive_utility_upper_bound(
    body_utility_lower_bound: i32,
    cfg: &CompressionStepConfig,
) -> i32 {
    if cfg.no_other_util { 0 } else { -body_utility_lower_bound }
}
/// Exact compressive utility of `pattern`: per-location savings are computed,
/// overlapping/nested uses are resolved bottom-up, then each task contributes
/// only its cheapest rewritten root program.
fn compressive_utility(pattern: &Pattern, shared: &SharedData) -> UtilityCalculation {
    // Savings from rewriting at each match location in isolation.
    let loc_utilities: Vec<i32> = get_utility_of_loc_once(pattern, shared);
    // Best cumulative utility reachable at every node, plus the per-location
    // decision of whether rewriting there actually won.
    let (cumulative, corrected_utils) = bottom_up_utility_correction(pattern, shared, &loc_utilities);
    // Sum, over tasks, of the cheapest post-rewrite root cost for that task.
    let best_cost_per_task: i32 = shared.root_idxs_of_task.iter().map(|root_idxs| {
        root_idxs.iter()
            .map(|idx| shared.init_cost_by_root_idx[*idx] - cumulative[usize::from(shared.roots[*idx])])
            .min()
            .unwrap()
    }).sum();
    UtilityCalculation { util: shared.init_cost - best_cost_per_task, corrected_utils }
}
/// Utility gained by rewriting each match location exactly once (independent
/// of other locations): body savings minus the application overhead, plus a
/// bonus for arguments an ivar reuses. Locations whose arguments still
/// contain free ivars get utility 0 (they can't be rewritten).
fn get_utility_of_loc_once(pattern: &Pattern, shared: &SharedData) -> Vec<i32> {
    // Applying the invention costs one terminal (its name) plus one app
    // nonterminal per argument.
    let app_penalty = - (COST_TERMINAL + COST_NONTERMINAL * pattern.first_zid_of_ivar.len() as i32);
    // Ivars used more than once: each extra use saves the argument cost again.
    let multiused_ivars: Vec<(usize,i32)> = pattern.arg_choices.iter()
        .map(|labelled| labelled.ivar)
        .counts()
        .iter()
        .filter(|(_ivar, count)| **count > 1)
        .map(|(ivar, count)| (*ivar, (*count - 1) as i32))
        .collect();
    pattern.match_locations.iter().map(|loc| {
        // An argument with free ivars makes this location unrewritable.
        let has_free_ivar_arg = pattern.first_zid_of_ivar.iter().any(|zid| {
            let shifted_arg = shared.arg_of_zid_node[*zid][loc].shifted_id;
            !shared.egraph[shifted_arg].data.free_ivars.is_empty()
        });
        if has_free_ivar_arg {
            return 0;
        }
        let multiuse_bonus: i32 = multiused_ivars.iter()
            .map(|(ivar, count)| count * shared.arg_of_zid_node[pattern.first_zid_of_ivar[*ivar]][loc].cost)
            .sum();
        pattern.body_utility + app_penalty + multiuse_bonus
    }).collect()
}
/// Bottom-up dynamic program that corrects per-location utilities for overlap:
/// rewriting a node precludes rewriting strictly inside it (except inside the
/// invention's arguments), so at each match location we compare "rewrite here
/// plus best results inside the arguments" against "don't rewrite here, keep
/// best results in the children" and take the max.
///
/// Returns, per node id, the best cumulative utility achievable in that
/// subtree, and per match location whether rewriting there was the winner.
///
/// Relies on `shared.treenodes` being in topological (children-first) order
/// with node ids equal to their indices.
fn bottom_up_utility_correction(pattern: &Pattern, shared:&SharedData, utility_of_loc_once: &[i32]) -> (Vec<i32>,FxHashMap<Id,bool>) {
let mut cumulative_utility_of_node: Vec<i32> = vec![0; shared.treenodes.len()];
let mut corrected_utils: FxHashMap<Id,bool> = Default::default();
// Children-first traversal: by the time we reach `node`, every child's
// cumulative utility is already final.
for node in shared.treenodes.iter() {
// Utility if we do NOT rewrite at this node: children keep their best.
let utility_without_rewrite: i32 = match &shared.node_of_id[usize::from(*node)] {
Lambda::Lam([b]) => cumulative_utility_of_node[usize::from(*b)],
Lambda::App([f,x]) => cumulative_utility_of_node[usize::from(*f)] + cumulative_utility_of_node[usize::from(*x)],
Lambda::Prim(_) | Lambda::Var(_) => 0,
// IVars never appear in concrete programs, and the Programs root was
// removed from treenodes upstream.
Lambda::IVar(_) | Lambda::Programs(_) => unreachable!(),
};
assert!(utility_without_rewrite >= 0);
// match_locations is sorted, so membership is a binary search.
if let Ok(idx) = pattern.match_locations.binary_search(node) {
// If we rewrite here, subtrees under the invention body are gone, but
// the arguments survive — their cumulative utilities still count.
let utility_of_args: i32 = pattern.first_zid_of_ivar.iter()
.map(|zid| cumulative_utility_of_node[usize::from(shared.arg_of_zid_node[*zid][node].unshifted_id)])
.sum();
let utility_with_rewrite = utility_of_args + utility_of_loc_once[idx];
// Record the decision so the rewriter can replay it later.
let chose_to_rewrite = utility_with_rewrite > utility_without_rewrite;
cumulative_utility_of_node[usize::from(*node)] = std::cmp::max(utility_with_rewrite, utility_without_rewrite);
corrected_utils.insert(*node,chose_to_rewrite);
} else if utility_without_rewrite != 0 {
cumulative_utility_of_node[usize::from(*node)] = utility_without_rewrite;
}
}
(cumulative_utility_of_node,corrected_utils)
}
/// Result of an exact compressive-utility computation for a candidate pattern.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UtilityCalculation {
/// Total compressive utility after bottom-up overlap correction.
pub util: i32,
/// Per match location: true if rewriting at that node beat not rewriting.
pub corrected_utils: FxHashMap<Id,bool>, }
/// Runs iterative compression: up to `iterations` rounds of `compression_step()`,
/// each time rewriting the programs with the best invention found and feeding the
/// rewritten programs into the next round. Stops early if a round finds nothing.
///
/// * `train_programs_expr` - the programs to compress
/// * `test_programs_expr` - optional held-out programs; if present, all learned
///   inventions are applied to them at the end and the ratio is reported
/// * `iterations` - maximum number of inventions to learn
/// * `cfg` - search configuration (also forwarded to each step)
/// * `tasks` - task name for each root program, used to group programs per task
/// * `prev_dc_inv_to_inv_strs` - (name, dreamcoder string) pairs for inventions
///   from earlier runs, used when emitting dreamcoder-format output
///
/// Returns one `CompressionStepResult` per invention found, in the order learned.
pub fn compression(
    train_programs_expr: &Expr,
    test_programs_expr: &Option<Expr>,
    iterations: usize,
    cfg: &CompressionStepConfig,
    tasks: &[String],
    prev_dc_inv_to_inv_strs: &[(String, String)],
) -> Vec<CompressionStepResult> {
    // Invention names continue from prior runs so "fn_i" indices never collide.
    let num_prior_inventions = prev_dc_inv_to_inv_strs.len();
    let mut rewritten: Expr = train_programs_expr.clone();
    let mut step_results: Vec<CompressionStepResult> = Default::default();
    let tstart = std::time::Instant::now();
    for i in 0..iterations {
        println!("{}",format!("\n=======Iteration {}=======",i).blue().bold());
        let inv_name = format!("fn_{}", num_prior_inventions + step_results.len());
        // Take ownership of just the top-ranked result; `into_iter().next()`
        // avoids the `res[0].clone()` of the whole CompressionStepResult.
        let best = compression_step(
            &rewritten,
            &inv_name,
            cfg,
            &step_results,
            tasks,
            prev_dc_inv_to_inv_strs).into_iter().next();
        if let Some(res) = best {
            rewritten = res.rewritten.clone();
            println!("Chose Invention {}: {}", res.inv.name, res);
            step_results.push(res);
        } else {
            // No invention with positive utility exists; further rounds would
            // search the same programs again, so stop.
            println!("No inventions found at iteration {}",i);
            break;
        }
    }
    println!("{}","\n=======Compression Summary=======".blue().bold());
    println!("Found {} inventions", step_results.len());
    println!("Cost Improvement: ({:.2}x better) {} -> {}", compression_factor(train_programs_expr,&rewritten), train_programs_expr.cost(), rewritten.cost());
    for res in step_results.iter() {
        println!("{} ({:.2}x wrt orig): {}" , res.inv.name.clone().blue(), compression_factor(train_programs_expr, &res.rewritten), res);
    }
    println!("Time: {}ms", tstart.elapsed().as_millis());
    // --follow-track targets can be pruned by the optimizations; warn unless
    // every pruning optimization is already disabled.
    if cfg.follow_track && !(
        cfg.no_opt_free_vars
        && cfg.no_opt_single_task
        && cfg.no_opt_upper_bound
        && cfg.no_opt_force_multiuse
        && cfg.no_opt_useless_abstract
        && cfg.no_opt_arity_zero)
    {
        println!("{} you often want to run --follow-track with --no-opt otherwise your target may get pruned", "[WARNING]".yellow());
    }
    if let Some(e) = test_programs_expr {
        println!("Test set compression with all inventions applied: {}", compression_factor(e, &rewrite_with_inventions(e.clone(), &step_results.iter().map(|r| r.inv.clone()).collect::<Vec<Invention>>())));
    }
    step_results
}
/// One full compression step: builds the egraph and all per-node lookup tables
/// for `programs_expr`, seeds arity-zero candidates, runs the (optionally
/// multithreaded) pattern search, and returns the surviving candidates (best
/// first) as fully rewritten results.
///
/// * `programs_expr` - a `Programs` node whose children are the root programs
/// * `new_inv_name` - name given to the invention found this step (e.g. "fn_3")
/// * `past_invs` - results of earlier steps, for dreamcoder-name translation
/// * `task_name_of_root_idx` - task name of each root program; only the
///   cheapest program per task counts toward cost
/// * `prev_dc_inv_to_inv_strs` - dreamcoder translations from prior runs
pub fn compression_step(
programs_expr: &Expr,
new_inv_name: &str, cfg: &CompressionStepConfig,
past_invs: &[CompressionStepResult], task_name_of_root_idx: &[String],
prev_dc_inv_to_inv_strs: &[(String, String)],
) -> Vec<CompressionStepResult> {
let tstart_total = std::time::Instant::now();
let tstart_prep = std::time::Instant::now();
let mut tstart = std::time::Instant::now();
// Build an egraph holding every (deduplicated) subtree of every program.
let mut egraph: EGraph = Default::default();
let programs_node = egraph.add_expr(programs_expr.into());
egraph.rebuild();
println!("set up egraph: {:?}ms", tstart.elapsed().as_millis());
tstart = std::time::Instant::now();
// One root id per program (children of the Programs node).
let roots: Vec<Id> = egraph[programs_node].nodes[0].children().to_vec();
let mut treenodes: Vec<Id> = topological_ordering(programs_node,&egraph);
// Invariant the rest of this function indexes by: node ids equal their
// position in the topological ordering.
assert!(treenodes.iter().enumerate().all(|(i,node)| i == usize::from(*node)));
let node_of_id: Vec<Lambda> = treenodes.iter().map(|node| egraph[*node].nodes[0].clone()).collect();
// Drop the Programs root itself; it's the last id, so the id==index
// invariant still holds for everything that remains.
treenodes.retain(|id| *id != programs_node);
println!("got roots, treenodes, and cloned egraph contents: {:?}ms", tstart.elapsed().as_millis());
tstart = std::time::Instant::now();
// How many times each node appears across all programs (total and per root).
let (num_paths_to_node, num_paths_to_node_by_root_idx) : (Vec<i32>, Vec<Vec<i32>>) = num_paths_to_node(&roots, &treenodes, &egraph);
println!("num_paths_to_node(): {:?}ms", tstart.elapsed().as_millis());
tstart = std::time::Instant::now();
// Deduplicate task names and group root programs by task.
let mut task_name_of_task: Vec<String> = vec![];
let mut task_of_root_idx: Vec<usize> = vec![];
let mut root_idxs_of_task: Vec<Vec<usize>> = vec![];
for (root_idx,task_name) in task_name_of_root_idx.iter().enumerate() {
// Reuse the task index if seen before, otherwise register a new task.
let task = task_name_of_task.iter().position(|name| name == task_name)
.unwrap_or_else(||{
task_name_of_task.push(task_name.clone());
root_idxs_of_task.push(vec![]);
task_name_of_task.len() - 1
});
task_of_root_idx.push(task);
root_idxs_of_task[task].push(root_idx);
}
let tasks_of_node: Vec<FxHashSet<usize>> = associate_tasks(programs_node, &egraph, &treenodes, &task_of_root_idx);
let init_cost_by_root_idx: Vec<i32> = roots.iter().map(|id| egraph[*id].data.inventionless_cost).collect();
// Baseline cost: per task, only the cheapest program counts.
let init_cost: i32 = root_idxs_of_task.iter().map(|root_idxs|
root_idxs.iter().map(|idx| init_cost_by_root_idx[*idx]).min().unwrap()
).sum();
println!("associate_tasks() and other task stuff: {:?}ms", tstart.elapsed().as_millis());
println!("num unique tasks: {}", task_name_of_task.len());
println!("num unique programs: {}", roots.len());
tstart = std::time::Instant::now();
// Per-node costs: once = cost of one occurrence; all = cost summed over
// every path to the node.
let cost_of_node_once: Vec<i32> = treenodes.iter().map(|node| egraph[*node].data.inventionless_cost).collect();
let cost_of_node_all: Vec<i32> = treenodes.iter().map(|node| cost_of_node_once[usize::from(*node)] * num_paths_to_node[usize::from(*node)]).collect();
let free_vars_of_node: Vec<FxHashSet<i32>> = treenodes.iter().map(|node| egraph[*node].data.free_vars.clone()).collect();
println!("cost_of_node structs: {:?}ms", tstart.elapsed().as_millis());
tstart = std::time::Instant::now();
// Zipper tables: every path from a node down to a subterm, and the argument
// each zipper extracts at each node.
let (zid_of_zip,
zip_of_zid,
arg_of_zid_node,
zids_of_node,
extensions_of_zid) = get_zippers(&treenodes, &cost_of_node_once, &mut egraph);
println!("get_zippers(): {:?}ms", tstart.elapsed().as_millis());
tstart = std::time::Instant::now();
println!("{} zips", zip_of_zid.len());
println!("arg_of_zid_node size: {}", arg_of_zid_node.len());
// Optional --track expression to watch during the search.
let tracking: Option<Tracking> = cfg.track.as_ref().map(|s|{
let expr: Expr = s.parse().unwrap();
let zids_of_ivar = zids_of_ivar_of_expr(&expr, &zid_of_zip);
Tracking { expr, zids_of_ivar }
});
println!("Tracking setup: {:?}ms", tstart.elapsed().as_millis());
let mut stats: Stats = Default::default();
tstart = std::time::Instant::now();
// Seed the donelist with arity-zero inventions: each eligible subtree,
// taken whole, is itself a candidate abstraction.
let mut donelist: Vec<FinishedPattern> = Default::default();
if !cfg.no_opt_arity_zero {
for node in treenodes.iter() {
// Subtrees with free variables can't be closed abstractions.
if !cfg.no_opt_free_vars && !egraph[*node].data.free_vars.is_empty() {
if !cfg.no_stats { stats.free_vars_fired += 1; };
continue;
}
// An invention used in only one task can't help compression across tasks.
if !cfg.no_opt_single_task && tasks_of_node[usize::from(*node)].len() < 2 {
if !cfg.no_stats { stats.single_task_fired += 1; };
continue;
}
let match_locations = vec![*node];
let body_utility = cost_of_node_once[usize::from(*node)];
// Each use saves the subtree cost minus the terminal that replaces it;
// per task only the cheapest resulting program counts.
let compressive_utility: i32 = init_cost - root_idxs_of_task.iter().map(|root_idxs|
root_idxs.iter().map(|idx| init_cost_by_root_idx[*idx] - num_paths_to_node_by_root_idx[*idx][usize::from(*node)] * (cost_of_node_once[usize::from(*node)] - COST_TERMINAL))
.min().unwrap()
).sum::<i32>();
let utility = compressive_utility + noncompressive_utility(body_utility, cfg);
if utility <= 0 { continue; }
let pattern = Pattern {
holes: vec![],
arg_choices: vec![],
first_zid_of_ivar: vec![],
match_locations,
utility_upper_bound: utility,
body_utility,
tracked: false,
};
let finished_pattern = FinishedPattern {
pattern,
utility,
compressive_utility,
util_calc: UtilityCalculation { util: compressive_utility, corrected_utils: Default::default()},
arity: 0,
usages: num_paths_to_node[usize::from(*node)]
};
donelist.push(finished_pattern);
}
}
println!("arity 0: {:?}ms", tstart.elapsed().as_millis());
tstart = std::time::Instant::now();
println!("got {} arity zero inventions", donelist.len());
// Mutex-protected worklist + donelist shared by all search threads.
let crit = CriticalMultithreadData::new(donelist, &treenodes, &cost_of_node_all, &num_paths_to_node, &egraph, cfg);
// Everything the search needs, bundled read-only behind an Arc (except the
// Mutex-wrapped crit and stats).
let shared = Arc::new(SharedData {
crit: Mutex::new(crit),
arg_of_zid_node,
treenodes: treenodes.clone(),
node_of_id,
programs_node,
roots,
zids_of_node,
zip_of_zid,
zid_of_zip,
extensions_of_zid,
egraph,
num_paths_to_node,
num_paths_to_node_by_root_idx,
tasks_of_node,
task_name_of_task,
task_of_root_idx,
root_idxs_of_task,
cost_of_node_once,
cost_of_node_all,
free_vars_of_node,
init_cost,
init_cost_by_root_idx,
stats: Mutex::new(stats),
cfg: cfg.clone(),
tracking,
});
println!("built SharedData: {:?}ms", tstart.elapsed().as_millis());
tstart = std::time::Instant::now();
if cfg.verbose_best {
let mut crit = shared.crit.lock();
if !crit.deref_mut().donelist.is_empty() {
let best_util = crit.deref_mut().donelist.first().unwrap().utility;
let best_expr: String = crit.deref_mut().donelist.first().unwrap().info(&shared);
println!("{} @ step=0 util={} for {}", "[new best utility]".blue(), best_util, best_expr);
}
}
println!("TOTAL PREP: {:?}ms", tstart_prep.elapsed().as_millis());
println!("running pattern search...");
// Run the search single-threaded or spawn cfg.threads workers that all pull
// from the shared worklist.
if cfg.threads == 1 {
stitch_search(Arc::clone(&shared));
} else {
let mut handles = vec![];
for _ in 0..cfg.threads {
let shared = Arc::clone(&shared);
handles.push(thread::spawn(move || {
stitch_search(shared);
}));
}
for handle in handles {
handle.join().unwrap();
}
}
println!("TOTAL SEARCH: {:?}ms", tstart.elapsed().as_millis());
println!("TOTAL PREP + SEARCH: {:?}ms", tstart_total.elapsed().as_millis());
tstart = std::time::Instant::now();
// All workers have joined, so this Arc is unique again; reclaim ownership.
let mut shared: SharedData = Arc::try_unwrap(shared).unwrap();
shared.crit.lock().deref_mut().update(cfg);
println!("{:?}", shared.stats.lock().deref_mut());
assert!(shared.crit.lock().deref_mut().worklist.is_empty());
let donelist: Vec<FinishedPattern> = shared.crit.lock().deref_mut().donelist.clone();
if cfg.dreamcoder_comparison {
println!("Timing point 1 (from the start of compression_step to final donelist): {:?}ms", tstart_total.elapsed().as_millis());
println!("Timing Comparison Point A (search) (millis): {}", tstart_total.elapsed().as_millis());
let tstart_rewrite = std::time::Instant::now();
// NOTE(review): indexes donelist[0]; panics if the search found no
// candidates — confirm this flag is only used when inventions exist.
rewrite_fast(&donelist[0], &shared, new_inv_name);
println!("Timing point 2 (rewriting the candidate): {:?}ms", tstart_rewrite.elapsed().as_millis());
println!("Timing Comparison Point B (search+rewrite) (millis): {}", tstart_total.elapsed().as_millis());
}
// Turn each surviving candidate into a full result (rewrites the corpus).
let mut results: Vec<CompressionStepResult> = vec![];
println!("Cost before: {}", shared.init_cost);
for (i,done) in donelist.iter().enumerate() {
let res = CompressionStepResult::new(done.clone(), new_inv_name, &mut shared, past_invs, prev_dc_inv_to_inv_strs);
println!("{}: {}", i, res);
if cfg.show_rewritten {
println!("rewritten:\n{}", res.rewritten.split_programs().iter().map(|p|p.to_string()).collect::<Vec<_>>().join("\n"));
}
results.push(res);
}
println!("post stuff: {:?}ms", tstart.elapsed().as_millis());
results
}