stitch_core/abstraction_learning/
rewriting.rs

1use crate::abstraction_learning::*;
2use crate::abstraction_learning::egraphs::EGraph;
3use lambdas::*;
4use rustc_hash::{FxHashMap,FxHashSet};
5use compression::*;
6
7
8
9/// convert an egraph Id to an Expr. Assumes one node per class (just picks the first node). Note
10/// that this could cause an infinite loop if the egraph didnt just have a single node in a class
11/// and instead the first node had a self loop.
12pub fn extract(eclass: Id, egraph: &EGraph) -> Expr {
13    debug_assert!(egraph[eclass].nodes.len() == 1);
14    match &egraph[eclass].nodes[0] {
15        Lambda::Prim(p) => Expr::prim(*p),
16        Lambda::Var(i) => Expr::var(*i),
17        Lambda::IVar(i) => Expr::ivar(*i),
18        Lambda::App([f,x]) => Expr::app(extract(*f,egraph), extract(*x,egraph)),
19        Lambda::Lam([b]) => Expr::lam(extract(*b,egraph)),
20        Lambda::Programs(roots) => Expr::programs(roots.iter().map(|r| extract(*r,egraph)).collect()),
21    }
22}
23
24/// like extract() but works on nodes
25pub fn extract_enode(enode: &Lambda, egraph: &EGraph) -> Expr {
26    match enode {
27        Lambda::Prim(p) => Expr::prim(*p),
28        Lambda::Var(i) => Expr::var(*i),
29        Lambda::IVar(i) => Expr::ivar(*i),
30        Lambda::App([f,x]) => Expr::app(extract(*f,egraph),extract(*x,egraph)),
31        Lambda::Lam([b]) => Expr::lam(extract(*b,egraph)),
32        _ => {panic!("not rendered")},
33    }
34}
35
/// A pending de Bruijn shift applied while rewriting variables inside arguments that were
/// shifted when they were pulled out of a match.
///
/// Starting from a variable's position, take as many steps upward as its index says. If that
/// meets or passes the absolute lambda depth `depth_cutoff` (counted from the top of the
/// program), the variable is shifted by `shift`. Rules stack: when several rules apply to the
/// same variable, all of their shifts are added.
struct ShiftRule {
    /// absolute lambda depth (from the program root) at which this rule takes effect
    depth_cutoff: i32,
    /// amount added to the index of every variable the rule applies to
    shift: i32,
}
44
/// Rewrite every program in `shared.roots` with the abstraction described by `pattern`, using
/// the new primitive `inv_name` to stand for it.
///
/// Each program is walked top-down. A node is replaced by `inv_name` applied (in
/// `first_zid_of_ivar` order) to its rewritten arguments, which are looked up in
/// `shared.arg_of_zid_node`, when all of these hold:
/// - the node is in `pattern.pattern.match_locations` (binary-searched, so presumably sorted);
/// - `pattern.util_calc.corrected_utils` either has no entry for it or has a `true` entry.
///
/// Everywhere else the node is rebuilt and the walk recurses into its children.
///
/// Returns one rewritten `Expr` per entry of `shared.roots`, in the same order.
///
/// # Panics
/// - if a rewritten de Bruijn variable would become negative, or an `IVar` appears in a program;
/// - unless `shared.cfg.no_mismatch_check` or `shared.cfg.utility_by_rewrite` is set, if the sum
///   over tasks of the cheapest rewritten root's cost differs from
///   `shared.init_cost - pattern.util_calc.util`.
pub fn rewrite_fast(
    pattern: &FinishedPattern,
    shared: &SharedData,
    inv_name: &str,
) -> Vec<Expr>
{
    // println!("rewriting with {}", pattern.info(&shared));

    /// Rewrite the subtree rooted at `unshifted_id`, which is an original program-tree node.
    ///
    /// `shift_rules` is a stack: a rule is pushed while an argument with a nonzero `shift` is
    /// being rewritten and popped right after, so variables inside it get adjusted.
    /// `refinements` is `Some((ids, arg_depth))` only inside a refinement, where rewriting with
    /// the invention is disabled and nodes listed in `ids` become variables.
    /// NOTE(review): every call site in `rewrite_fast` passes `None` for `refinements`, so the
    /// refinement paths below are currently unreachable.
    fn helper(
        pattern: &FinishedPattern,
        shared: &SharedData,
        unshifted_id: Id,
        total_depth: i32, // depth from the very root of the program down
        shift_rules: &mut Vec<ShiftRule>,
        inv_name: &str,
        refinements: Option<(&Vec<Id>,i32)>
    ) -> Expr
    {
        // we search using the *unshifted* id since it's an original program tree node
        if pattern.pattern.match_locations.binary_search(&unshifted_id).is_ok() // if the pattern matches here
           && (!pattern.util_calc.corrected_utils.contains_key(&unshifted_id) // and either we have no conflict (ie corrected_utils doesnt have an entry)
             || pattern.util_calc.corrected_utils[&unshifted_id]) // or we have a conflict but we choose to accept it (which is contextless in this top down approach so its the right move)
           && refinements.is_none() // AND we can't currently be in a refinement where rewriting is forbidden
        // Disabled extra condition, kept for reference: also require that no argument has a negative free var.
        //    && !pattern.pattern.first_zid_of_ivar.iter().any(|zid| // and there are no negative vars anywhere in the arguments
        //         shared.egraph[shared.arg_of_zid_node[*zid][&unshifted_id].id].data.free_vars.iter().any(|var| *var < 0))
        {
            // println!("inv applies at unshifted={} with shift={}", extract(unshifted_id,&shared.egraph), shift);
            let mut expr = Expr::prim(inv_name.into());
            // wrap the prim in all the Apps to args, one per ivar in `first_zid_of_ivar` order
            for (_ivar,zid) in pattern.pattern.first_zid_of_ivar.iter().enumerate() {
                let arg: &Arg = &shared.arg_of_zid_node[*zid][&unshifted_id];


                // assert!(shared.egraph[arg.shifted_id].data.free_vars.iter().all(|v| *v >= 0));
                // assert!(arg.id == egraphs::shift(arg.unshifted_id, arg.shift, &shared.egraph, None).unwrap());

                // the arg is rewritten from its *unshifted* node, so any shift it underwent is
                // replayed on its variables through a temporary ShiftRule at the current depth
                if arg.shift != 0 {
                    shift_rules.push(ShiftRule{depth_cutoff: total_depth, shift: arg.shift});
                }
                let rewritten_arg = helper(pattern, shared, arg.unshifted_id, total_depth, shift_rules, inv_name, None);
                if arg.shift != 0 {
                    shift_rules.pop(); // pop the rule back off after
                }
                expr = Expr::app(expr, rewritten_arg);
            }
            return expr
        }
        // println!("descending: {}", extract(unshifted_id,&shared.egraph));

        if let Some((refinements,arg_depth)) = refinements.as_ref() {
            if let Some(idx) = refinements.iter().position(|r| *r == unshifted_id) {
                // println!("found refinement!!!");
                // todo should this be `idx` or `refinements.len()-1-idx`?
                return Expr::var(total_depth - arg_depth + idx as i32); // if we didnt pass thru any lams on the way this would just be $0 and thus refer to the Expr::lam() wrapping our helper() call
            }
        }


        match &shared.node_of_id[usize::from(unshifted_id)] {
            Lambda::Prim(p) => Expr::prim(*p),
            Lambda::Var(i) => {
                // start from the original index and add the shift of every rule that applies
                let mut j = *i;
                for rule in shift_rules.iter() {
                    // take "i" steps upward from current depth and see if you meet or pass the cutoff.
                    // exactly equaling the cutoff counts as needing shifting.
                    if total_depth - i <= rule.depth_cutoff {
                        j += rule.shift;
                    }
                }
                if let Some((refinements,arg_depth)) = refinements.as_ref() {
                    // we're inside the *shifted arg* of a refinement so this var has already been shifted a bit btw
                    // tho thats kinda irrelevant right here
                    if j >= *arg_depth {
                        // j is pointing above the invention so we need to upshift it a bit to account for the new lambdas we added
                        j += refinements.len() as i32;
                    }
                }
                assert!(j >= 0, "{}", pattern.to_expr(shared));
                Expr::var(j)
            },
            Lambda::App([unshifted_f,unshifted_x]) => {
                Expr::app(
                    helper(pattern, shared, *unshifted_f, total_depth, shift_rules, inv_name, refinements),
                    helper(pattern, shared, *unshifted_x, total_depth, shift_rules, inv_name, refinements),
                )
            },
            Lambda::Lam([unshifted_b]) => {
                // entering a lambda increases the absolute depth used by the ShiftRule cutoffs
                Expr::lam(helper(pattern, shared, *unshifted_b, total_depth + 1, shift_rules, inv_name, refinements))
            },
            Lambda::IVar(_) => {
                panic!("attempted to rewrite with an ivar");
            },
            _ => unreachable!(),
        }
    }

    // a single shift-rule stack shared by all roots; helper pushes and pops around each arg it rewrites
    let shift_rules = &mut vec![];
    let rewritten_exprs: Vec<Expr> = shared.roots.iter().map(|root| {
        helper(pattern, shared, *root, 0, shift_rules, inv_name, None)
    }).collect();

    // sanity check: the rewritten programs must realize exactly the utility that was predicted.
    // Each task's cost is the cheapest of its roots.
    if !shared.cfg.no_mismatch_check && !shared.cfg.utility_by_rewrite {
        assert_eq!(
            shared.root_idxs_of_task.iter().map(|root_idxs|
                root_idxs.iter().map(|idx| rewritten_exprs[*idx].cost()).min().unwrap()
            ).sum::<i32>(),
            shared.init_cost - pattern.util_calc.util,
            "\n{}\n", pattern.info(shared)
        );
    }

    rewritten_exprs
}
157
158
159
160
161
162/// These are like Inventions but with a pointer to the body instead of an Expr
163#[derive(Debug, Clone, Eq, PartialEq, Hash, PartialOrd, Ord)]
164struct PtrInvention {
165    pub body:Id, // this will be a subtree which can have IVars
166    pub arity: usize, // also equal to max ivar in subtree + 1
167    pub name: String
168}
169impl PtrInvention {
170    pub fn new(body:Id, arity: usize, name: String) -> Self {
171        PtrInvention {
172            body,
173            arity,
174            name
175        }
176    }
177}
178
179/// Same as `rewrite_with_invention` but for multiple inventions, rewriting with one after another in order, compounding on each other
180pub fn rewrite_with_inventions(
181    e: Expr,
182    invs: &[Invention]
183) -> Expr {
184    let mut egraph = EGraph::default();
185    let root = egraph.add_expr(&e.into());
186    rewrite_with_inventions_egraph(root, invs, &mut egraph)
187}
188
189/// Rewrite `root` using an invention `inv`. This will use inventions everywhere
190/// as long as it decreases the cost. It will account for the fact that using an invention
191/// in a child could prevent the use of the invention in the parent - it will always do whatever
192/// gives the lowest cost.
193/// 
194/// For the `EGraph` argument here you can either pass in a fresh egraph constructed by `let mut egraph = EGraph::new(); egraph.add_expr(expr.into())`
195/// or if you make repeated calls to this function feel free to pass in the same egraph over and over. It doesn't matter what is in the EGraph already.
196
197pub fn rewrite_with_invention(
198    e: Expr,
199    inv: &Invention,
200) -> Expr {
201    let mut egraph = EGraph::default();
202    let root = egraph.add_expr(&e.into());
203    rewrite_with_invention_egraph(root, inv, &mut egraph)
204}
205
206/// Same as `rewrite_with_invention_egraph` but for multiple inventions, rewriting with one after another in order, compounding on each other
207pub fn rewrite_with_inventions_egraph(
208    root: Id,
209    invs: &[Invention],
210    egraph: &mut EGraph,
211) -> Expr {
212    let mut root = root;
213    for inv in invs.iter() {
214        let expr = rewrite_with_invention_egraph(root, inv, egraph);
215        root = egraph.add_expr(&expr.into());
216    }
217    extract(root,egraph)
218}
219
220/// Same as `rewrite_with_invention` but operates on an egraph instead of an Expr.
221/// 
222/// For the `EGraph` argument here you can either pass in a fresh egraph constructed by `let mut egraph = EGraph::new(); egraph.add_expr(expr.into())`
223/// or if you make repeated calls to this function feel free to pass in the same egraph over and over. It doesn't matter what is in the EGraph already
224/// as long as `root` is in it.
225pub fn rewrite_with_invention_egraph(
226    root: Id,
227    inv: &Invention,
228    egraph: &mut EGraph,
229) -> Expr {
230    let inv: PtrInvention = PtrInvention::new(egraph.add_expr(&inv.body.clone().into()), inv.arity, inv.name.clone());
231
232    let treenodes = topological_ordering(root, egraph);
233
234    assert!(!treenodes.iter().any(|n| egraph[*n].nodes[0] == Lambda::Prim(Symbol::from(&inv.name))),
235        "Invention {} already in tree", inv.name);
236
237    let mut nodecost_of_treenode: FxHashMap<Id,NodeCost> = Default::default();
238    
239    for treenode in treenodes.iter() {
240        // println!("processing id={}: {}", treenode, extract(*treenode, egraph) );
241
242        // clone to appease the borrow checker
243        let node = egraph[*treenode].nodes[0].clone();
244
245        let mut nodecost = NodeCost::new(egraph[*treenode].data.inventionless_cost);
246
247        // trying to use the invs at this node
248        if let Some(args) = match_expr_with_inv(*treenode, &inv, &mut nodecost_of_treenode, egraph) {
249            let cost: i32 =
250                COST_TERMINAL // the new primitive for this invention
251                + COST_NONTERMINAL * inv.arity as i32 // the chain of app()s needed to apply the new primitive
252                + args.iter()
253                    .map(|id| nodecost_of_treenode[id]
254                        .cost_under_inv(&inv)) // cost under ANY of the invs since we allow multiple to be used!
255                    .sum::<i32>(); // sum costs of actual args
256                    nodecost.new_cost_under_inv(inv.clone(), cost, Some(args));
257        }
258
259
260        // inventions based on specific node type
261        match node {
262            Lambda::IVar(_) => { unreachable!() }
263            Lambda::Var(_) => {},
264            Lambda::Prim(_) => {},
265            Lambda::App([f,x]) => {
266                let f_nodecost = &nodecost_of_treenode[&f];
267                let x_nodecost = &nodecost_of_treenode[&x];
268                                
269                // costs with inventions as 1 + fcost + xcost. Use inventionless cost as a default.
270                // if either fcost or xcost is None (ie infinite)
271                let fcost = f_nodecost.cost_under_inv(&inv);
272                let xcost = x_nodecost.cost_under_inv(&inv);
273                let cost = COST_NONTERMINAL+fcost+xcost;
274                nodecost.new_cost_under_inv(inv.clone(), cost, None);
275            }
276            Lambda::Lam([b]) => {
277                // just map +1 over the costs
278                let b_nodecost = &nodecost_of_treenode[&b];
279                let bcost = b_nodecost.cost_under_inv(&inv);
280                nodecost.new_cost_under_inv(inv.clone(), bcost + COST_NONTERMINAL, None);
281            }
282            Lambda::Programs(roots) => {
283                // no filtering for 2+ uses because we're just doing rewriting here
284                let cost = roots.iter().map(|root| {
285                        nodecost_of_treenode[root].cost_under_inv(&inv)
286                    }).sum();
287                    nodecost.new_cost_under_inv(inv.clone(), cost, None);
288            }
289        }
290
291        nodecost_of_treenode.insert(*treenode, nodecost);
292    }
293
294    // Now that we've calculated all the costs, we can extract the cheapest one
295    extract_from_nodecosts(root, &inv, &nodecost_of_treenode, egraph)
296}
297
298fn extract_from_nodecosts(
299    root: Id,
300    inv: &PtrInvention,
301    nodecost_of_treenode: &FxHashMap<Id,NodeCost>,
302    egraph: &EGraph,
303) -> Expr {
304
305    let target_cost = nodecost_of_treenode[&root].cost_under_inv(inv);
306
307    if let Some((inv,_cost,args)) = nodecost_of_treenode[&root].top_invention() {
308        if let Some(args) = args {
309            // invention was used here
310            let mut expr = Expr::prim(inv.name.clone().into());
311            // wrap the new primitive in app() calls. Note that you pass in the $0 args LAST given how appapplamlam works
312            for arg in args.iter() {
313                let arg_expr = extract_from_nodecosts(*arg, &inv, nodecost_of_treenode, egraph);
314                expr = Expr::app(expr,arg_expr);
315            }
316            assert_eq!(target_cost,expr.cost());
317            expr
318        } else {
319            // inventions were used in our children
320            let expr: Expr = match &egraph[root].nodes[0] {
321                Lambda::Prim(_) | Lambda::Var(_) | Lambda::IVar(_) => {unreachable!()},
322                Lambda::App([f,x]) => {
323                    let f_expr = extract_from_nodecosts(*f, &inv, nodecost_of_treenode, egraph);
324                    let x_expr = extract_from_nodecosts(*x, &inv, nodecost_of_treenode, egraph);
325                    Expr::app(f_expr,x_expr)
326                },
327                Lambda::Lam([b]) => {
328                    let b_expr = extract_from_nodecosts(*b, &inv, nodecost_of_treenode, egraph);
329                    Expr::lam(b_expr)
330                }
331                Lambda::Programs(roots) => {
332                    let root_exprs: Vec<Expr> = roots.iter()
333                        .map(|r| extract_from_nodecosts(*r, &inv, nodecost_of_treenode, egraph))
334                        .collect();
335                    Expr::programs(root_exprs)
336                }
337            };
338            assert_eq!(target_cost,expr.cost());
339            expr
340        }
341    } else {
342        // no invention was useful, just return original tree
343        let expr =  extract(root, egraph);
344        assert_eq!(target_cost,expr.cost());
345        expr
346    }
347}
348
349/// There will be one of these structs associated with each node, and it keeps
350/// track of the best inventions for that node, their costs, and their arguments.
351#[derive(Debug,Clone)]
352struct NodeCost {
353    inventionless_cost: i32,
354    inventionful_cost: FxHashMap<PtrInvention, (i32,Option<Vec<Id>>)>, // i32 = cost; and Some(args) gives the arguments if the invention is used at this node
355}
356
357impl NodeCost {
358    fn new(inventionless_cost: i32) -> Self {
359        Self {
360            inventionless_cost,
361            inventionful_cost: FxHashMap::default()
362        }
363    }
364    /// cost under an invention if it's useful for this node, else inventionless cost
365    fn cost_under_inv(&self, inv: &PtrInvention) -> i32 {
366        self.inventionful_cost.get(inv).map(|x|x.0).unwrap_or(self.inventionless_cost)
367    }
368    /// improve the cost using a new invention, or do nothing if we've already seen
369    /// a better cost for this invention. Also skip if inventionless cost is better.
370    fn new_cost_under_inv(&mut self, inv: PtrInvention, cost:i32, args: Option<Vec<Id>>) {
371        if cost < self.inventionless_cost
372                && (!self.inventionful_cost.contains_key(&inv) || cost < self.inventionful_cost[&inv].0)
373        {
374            self.inventionful_cost.insert(inv, (cost,args));
375        }
376    }
377    /// Get the top inventions in decreasing order of cost
378    #[allow(dead_code)] // todo at some point add tests for this
379    fn top_inventions(&self) -> Vec<PtrInvention> {
380        let mut top_inventions: Vec<PtrInvention> = self.inventionful_cost.keys().cloned().collect();
381        top_inventions.sort_by(|a,b| self.inventionful_cost[a].0.cmp(&self.inventionful_cost[b].0));
382        top_inventions
383    }
384    /// Get the top inventions in decreasing order of cost
385    fn top_invention(&self) -> Option<(PtrInvention,i32,Option<Vec<Id>>)> {
386        self.inventionful_cost.iter().min_by_key(|(_k,v)| v.0).map(|(k,v)| (k.clone(),v.0,v.1.clone()))
387    }
388}
389
390
391fn match_expr_with_inv(
392    root: Id,
393    inv: &PtrInvention,
394    best_inventions_of_treenode: &mut FxHashMap<Id, NodeCost>,
395    egraph: &mut EGraph,
396) -> Option<Vec<Id>> {
397    let mut args: Vec<Option<Id>> = vec![None;inv.arity];
398    let threadables = threadables_of_inv(inv.clone(), egraph);
399    if match_expr_with_inv_rec(root, inv.body, 0, &mut args, &threadables, best_inventions_of_treenode, egraph) {
400        assert!(args.iter().all(|x| x.is_some()), "{:?}\n{}\n{}", args, extract(root,egraph), extract(inv.body,egraph)); // if any didnt unwrap() fine that would mean some variable wasnt used at all in the invention body
401        Some(args.iter().map(|arg| arg.unwrap()).collect()) 
402    } else {
403        None
404    }
405}
406
/// Try to match tree node `root` against invention-body node `inv`, recursing structurally and
/// recording ivar bindings in `args` (`args[j]` is the id bound to `#j`).
///
/// - `depth` is the number of invention lambdas entered so far. A de Bruijn var in `root` that
///   is `< depth` points at one of those lambdas, i.e. inside the invention.
/// - `threadables` are the invention nodes where threading may happen (see
///   `threadables_of_inv`).
/// - New nodes (shifted and/or lambda-wrapped arguments) may be added to `egraph`. Each new
///   argument node gets a cost entry copied into `best_inventions_of_treenode`.
///
/// Returns `true` on a successful match. On `false`, `args` may still hold partial bindings
/// (only the threading path restores it); the top-level caller `match_expr_with_inv` discards
/// them in that case.
fn match_expr_with_inv_rec(
    root: Id,
    inv: Id,
    depth: i32,
    args: &mut [Option<Id>],
    threadables: &FxHashSet<Id>,
    best_inventions_of_treenode: &mut FxHashMap<Id, NodeCost>,
    egraph: &mut EGraph,
) -> bool {
    // println!("comparing:\n\t{}\n\t{}",
    //     extract(root, egraph).to_string(),
    //     extract(inv, egraph).to_string()
    // );
    // println!("processing:\n\troot:{}\n\tinv:{} ts:{}", extract(root,egraph), extract(inv,egraph), threadables.contains(&inv));
    match (&egraph[root].nodes[0].clone(), &egraph[inv].nodes[0].clone()) { // clone so no borrow of `egraph` is held while it's mutated below
        (Lambda::Prim(p), Lambda::Prim(q)) => { p == q },
        (Lambda::Var(i), Lambda::Var(j)) => { i == j },
        (root_node, Lambda::App([g,y])) if threadables.contains(&inv) => {
            // todo this whole section is a nightmare so make sure there arent bugs

            // a thread site only applies when the set of internal pointers (free vars of `root`
            // that point inside the invention, i.e. `< depth`) is the same as the thread
            // site's set of free vars.
            let internal_free_vars: FxHashSet<i32> = egraph[root].data.free_vars.iter().filter(|i| **i < depth).cloned().collect();
            let num_to_thread = internal_free_vars.len() as i32;
            if internal_free_vars == egraph[inv].data.free_vars {
                // println!("threading");
                // Free vars match exactly, so we could thread here. Note that if we match here then an inner thread site won't match.
                // However, there is also some chance a nonthreading approach could work too, which is always simpler.
                // For example, when matching (#0 $0) against (inc $0) we can simply set #0=inc instead of #0=(lam (inc $0)).
                // So try the plain structural match first, and restore `args` if it fails.
                if let Lambda::App([f,x]) = root_node {
                    let cloned_args: Vec<_> = args.to_vec();
                    if match_expr_with_inv_rec(*f, *g, depth, args, threadables, best_inventions_of_treenode, egraph)
                    && match_expr_with_inv_rec(*x, *y, depth, args, threadables, best_inventions_of_treenode, egraph) {
                        return true;
                    }
                    args.clone_from_slice(cloned_args.as_slice());
                }

                // Now lets build the desired argument out of `root` by basically
                // following what bubbling up would normally do: downshifting for each
                // internal lambda in the invention, except adding an extra lambda on top
                // in the case of threaded variables
                let mut arg = root;
                for i in 0..depth {
                    if egraph[inv].data.free_vars.contains(&i) {
                        // protect $0 before continuing with the shift
                        arg = egraph.add(Lambda::Lam([arg]));
                    }
                    arg = shift(arg, -1, egraph, &mut None).unwrap();
                }

                // Now copy over the best_inventions from `root` to the new `arg` node.
                // Threading adds one COST_NONTERMINAL per threaded variable, to every recorded cost.
                if !best_inventions_of_treenode.contains_key(&arg) {
                    let mut cloned = best_inventions_of_treenode[&root].clone();
                    cloned.inventionless_cost += COST_NONTERMINAL * num_to_thread;
                    // we'll force this arg to not use a toplevel invention at the "lam" node hence the None
                    cloned.inventionful_cost.iter_mut().for_each(|(_key, val)| {val.0 += COST_NONTERMINAL * num_to_thread; val.1 = None});
                    best_inventions_of_treenode.insert(arg,cloned);
                }

                // the ivar heading this thread site (assumes the site contains exactly one free ivar)
                let ivar = *egraph[inv].data.free_ivars.iter().next().unwrap() as usize;

                // now finally check that these results align
                if let Some(v) = args[ivar] {
                    arg == v // if #j was bound to some id `v` before, then `root` must be `v` for this to match
                } else {
                    args[ivar] = Some(arg);
                    // println!("Assigned #{} = {}", ivar, extract(arg,egraph));
                    true
                }
            } else {
                // not threadable case: fall back to a plain structural App-vs-App match
                if let Lambda::App([f,x]) = root_node {
                    return match_expr_with_inv_rec(*f, *g, depth, args, threadables, best_inventions_of_treenode, egraph)
                    && match_expr_with_inv_rec(*x, *y, depth, args, threadables, best_inventions_of_treenode, egraph)
                }
                false
            }
        },
        (Lambda::App([f,x]), Lambda::App([g,y])) => {
            // not threadable case
            match_expr_with_inv_rec(*f, *g, depth, args, threadables, best_inventions_of_treenode, egraph)
            && match_expr_with_inv_rec(*x, *y, depth, args, threadables, best_inventions_of_treenode, egraph)
        }
        (Lambda::Lam([b]), Lambda::Lam([c])) => {
            // entering an invention lambda: one more var index now points inside the invention
            match_expr_with_inv_rec(*b, *c, depth+1, args, threadables, best_inventions_of_treenode, egraph)
        },
        (_, Lambda::IVar(j)) => {
            // We need to bind #j to `root`
            // First `root` needs to be downshifted by `depth`. There are 3 cases:
            let shifted_root: Id = if egraph[root].data.free_vars.is_empty() {
                // 1. `root` has no free variables so no shifting is needed
                root
            } else if egraph[root].data.free_vars.iter().min().unwrap() - depth >= 0 {
                // 2. `root` has free variables but they all point outside the invention so are safe to decrement

                // Downshift `node` by `depth` and give the result a cost entry cloned from `node`'s,
                // recursively shifting the args of any invention use recorded there.
                // copy the cost of the unshifted node to the shifted node (see PR#1 comments for why this is safe)
                fn shift_and_fix(node: Id, depth: i32, best_inventions_of_treenode: &mut FxHashMap<Id,NodeCost>, egraph: &mut EGraph) -> Id {
                    let shifted_node = shift(node, -depth, egraph, &mut None).unwrap();
                    if best_inventions_of_treenode.contains_key(&shifted_node) {
                        return shifted_node; // this has already been handled
                    }
                    let mut cloned = best_inventions_of_treenode[&node].clone();
                    // adjust the args needed for the shifted node so that they are shifted too. Note that this
                    // is only safe because we only ever use this for a single invention at a time so this hashtable
                    // actually only has one invention in it. This will be adjusted to be way more clear in the use-conflicts
                    // PR that hasnt yet been merged.
                    // Note that you propagate down the same "depth" for the shift amount for all recursive calls, I think this is correct

                    cloned.inventionful_cost.iter_mut().for_each(|(_key, val)| {
                        if let Some(args) = &mut val.1 {
                            args.iter_mut().for_each(|arg| *arg = shift_and_fix(*arg, depth, best_inventions_of_treenode, egraph));
                        }
                    });
                    best_inventions_of_treenode.insert(shifted_node,cloned);
                    shifted_node
                }
                shift_and_fix(root, depth, best_inventions_of_treenode, egraph)
            } else {
                // 3. `root` has a free var pointing inside the invention, which would need threading
                return false // threading needed but this is not a thread site
            };

            if let Some(v) = args[*j as usize] {
                shifted_root == v // if #j was bound to some id `v` before, then `root` must be `v` for this to match
            } else {
                args[*j as usize] = Some(shifted_root);
                // println!("Assigned #{} = {}", j, extract(shifted_root,egraph));
                true
            }
        },
        // any other combination (e.g. Prim vs Var, Lam vs App) can't match
        _ => { false }
    }
}
541
542fn threadables_of_inv(inv: PtrInvention, egraph: &EGraph) -> FxHashSet<Id> {
543    // a threadable is a (app #i $j) or (app <threadable> $j)
544    // assert j > k sanity check
545    // println!("Invention: {}", inv.to_expr(egraph));
546    let mut threadables: FxHashSet<Id> = Default::default();
547    let nodes = topological_ordering(inv.body, egraph);
548    for node in nodes {
549        if let Lambda::App([f,x]) = egraph[node].nodes[0] {
550            if matches!(egraph[x].nodes[0], Lambda::Var(_))
551                    && (matches!(egraph[f].nodes[0], Lambda::IVar(_)) || threadables.contains(&f))
552            {
553                threadables.insert(node);
554                // println!("Identified threadable: {}", extract(node,egraph));
555            }
556        }
557    }
558    threadables
559}