1use crate::abstraction_learning::*;
2use crate::abstraction_learning::egraphs::EGraph;
3use lambdas::*;
4use rustc_hash::{FxHashMap,FxHashSet};
5use compression::*;
6
7
8
9pub fn extract(eclass: Id, egraph: &EGraph) -> Expr {
13 debug_assert!(egraph[eclass].nodes.len() == 1);
14 match &egraph[eclass].nodes[0] {
15 Lambda::Prim(p) => Expr::prim(*p),
16 Lambda::Var(i) => Expr::var(*i),
17 Lambda::IVar(i) => Expr::ivar(*i),
18 Lambda::App([f,x]) => Expr::app(extract(*f,egraph), extract(*x,egraph)),
19 Lambda::Lam([b]) => Expr::lam(extract(*b,egraph)),
20 Lambda::Programs(roots) => Expr::programs(roots.iter().map(|r| extract(*r,egraph)).collect()),
21 }
22}
23
24pub fn extract_enode(enode: &Lambda, egraph: &EGraph) -> Expr {
26 match enode {
27 Lambda::Prim(p) => Expr::prim(*p),
28 Lambda::Var(i) => Expr::var(*i),
29 Lambda::IVar(i) => Expr::ivar(*i),
30 Lambda::App([f,x]) => Expr::app(extract(*f,egraph),extract(*x,egraph)),
31 Lambda::Lam([b]) => Expr::lam(extract(*b,egraph)),
32 _ => {panic!("not rendered")},
33 }
34}
35
/// A pending de Bruijn index adjustment applied while rewriting an invention argument.
///
/// While the rule is active, a variable `i` encountered at absolute depth `d` has
/// `shift` added to it whenever `d - i <= depth_cutoff`.
struct ShiftRule {
    /// Absolute lambda depth at which the rule was pushed (the rewrite site's depth).
    depth_cutoff: i32,
    /// Amount added to qualifying variable indices.
    shift: i32,
}
44
45pub fn rewrite_fast(
46 pattern: &FinishedPattern,
47 shared: &SharedData,
48 inv_name: &str,
49) -> Vec<Expr>
50{
51 fn helper(
53 pattern: &FinishedPattern,
54 shared: &SharedData,
55 unshifted_id: Id,
56 total_depth: i32, shift_rules: &mut Vec<ShiftRule>,
58 inv_name: &str,
59 refinements: Option<(&Vec<Id>,i32)>
60 ) -> Expr
61 {
62 if pattern.pattern.match_locations.binary_search(&unshifted_id).is_ok() && (!pattern.util_calc.corrected_utils.contains_key(&unshifted_id) || pattern.util_calc.corrected_utils[&unshifted_id]) && refinements.is_none() {
70 let mut expr = Expr::prim(inv_name.into());
72 for (_ivar,zid) in pattern.pattern.first_zid_of_ivar.iter().enumerate() {
74 let arg: &Arg = &shared.arg_of_zid_node[*zid][&unshifted_id];
75
76
77 if arg.shift != 0 {
81 shift_rules.push(ShiftRule{depth_cutoff: total_depth, shift: arg.shift});
82 }
83 let rewritten_arg = helper(pattern, shared, arg.unshifted_id, total_depth, shift_rules, inv_name, None);
84 if arg.shift != 0 {
85 shift_rules.pop(); }
87 expr = Expr::app(expr, rewritten_arg);
88 }
89 return expr
90 }
91 if let Some((refinements,arg_depth)) = refinements.as_ref() {
94 if let Some(idx) = refinements.iter().position(|r| *r == unshifted_id) {
95 return Expr::var(total_depth - arg_depth + idx as i32); }
99 }
100
101
102 match &shared.node_of_id[usize::from(unshifted_id)] {
103 Lambda::Prim(p) => Expr::prim(*p),
104 Lambda::Var(i) => {
105 let mut j = *i;
106 for rule in shift_rules.iter() {
107 if total_depth - i <= rule.depth_cutoff {
110 j += rule.shift;
111 }
112 }
113 if let Some((refinements,arg_depth)) = refinements.as_ref() {
114 if j >= *arg_depth {
117 j += refinements.len() as i32;
119 }
120 }
121 assert!(j >= 0, "{}", pattern.to_expr(shared));
122 Expr::var(j)
123 }, Lambda::App([unshifted_f,unshifted_x]) => {
125 Expr::app(
126 helper(pattern, shared, *unshifted_f, total_depth, shift_rules, inv_name, refinements),
127 helper(pattern, shared, *unshifted_x, total_depth, shift_rules, inv_name, refinements),
128 )
129 },
130 Lambda::Lam([unshifted_b]) => {
131 Expr::lam(helper(pattern, shared, *unshifted_b, total_depth + 1, shift_rules, inv_name, refinements))
132 },
133 Lambda::IVar(_) => {
134 panic!("attempted to rewrite with an ivar");
135 },
136 _ => unreachable!(),
137 }
138 }
139
140 let shift_rules = &mut vec![];
141 let rewritten_exprs: Vec<Expr> = shared.roots.iter().map(|root| {
142 helper(pattern, shared, *root, 0, shift_rules, inv_name, None)
143 }).collect();
144
145 if !shared.cfg.no_mismatch_check && !shared.cfg.utility_by_rewrite {
146 assert_eq!(
147 shared.root_idxs_of_task.iter().map(|root_idxs|
148 root_idxs.iter().map(|idx| rewritten_exprs[*idx].cost()).min().unwrap()
149 ).sum::<i32>(),
150 shared.init_cost - pattern.util_calc.util,
151 "\n{}\n", pattern.info(shared)
152 );
153 }
154
155 rewritten_exprs
156}
157
158
159
160
161
162#[derive(Debug, Clone, Eq, PartialEq, Hash, PartialOrd, Ord)]
164struct PtrInvention {
165 pub body:Id, pub arity: usize, pub name: String
168}
169impl PtrInvention {
170 pub fn new(body:Id, arity: usize, name: String) -> Self {
171 PtrInvention {
172 body,
173 arity,
174 name
175 }
176 }
177}
178
179pub fn rewrite_with_inventions(
181 e: Expr,
182 invs: &[Invention]
183) -> Expr {
184 let mut egraph = EGraph::default();
185 let root = egraph.add_expr(&e.into());
186 rewrite_with_inventions_egraph(root, invs, &mut egraph)
187}
188
189pub fn rewrite_with_invention(
198 e: Expr,
199 inv: &Invention,
200) -> Expr {
201 let mut egraph = EGraph::default();
202 let root = egraph.add_expr(&e.into());
203 rewrite_with_invention_egraph(root, inv, &mut egraph)
204}
205
206pub fn rewrite_with_inventions_egraph(
208 root: Id,
209 invs: &[Invention],
210 egraph: &mut EGraph,
211) -> Expr {
212 let mut root = root;
213 for inv in invs.iter() {
214 let expr = rewrite_with_invention_egraph(root, inv, egraph);
215 root = egraph.add_expr(&expr.into());
216 }
217 extract(root,egraph)
218}
219
220pub fn rewrite_with_invention_egraph(
226 root: Id,
227 inv: &Invention,
228 egraph: &mut EGraph,
229) -> Expr {
230 let inv: PtrInvention = PtrInvention::new(egraph.add_expr(&inv.body.clone().into()), inv.arity, inv.name.clone());
231
232 let treenodes = topological_ordering(root, egraph);
233
234 assert!(!treenodes.iter().any(|n| egraph[*n].nodes[0] == Lambda::Prim(Symbol::from(&inv.name))),
235 "Invention {} already in tree", inv.name);
236
237 let mut nodecost_of_treenode: FxHashMap<Id,NodeCost> = Default::default();
238
239 for treenode in treenodes.iter() {
240 let node = egraph[*treenode].nodes[0].clone();
244
245 let mut nodecost = NodeCost::new(egraph[*treenode].data.inventionless_cost);
246
247 if let Some(args) = match_expr_with_inv(*treenode, &inv, &mut nodecost_of_treenode, egraph) {
249 let cost: i32 =
250 COST_TERMINAL + COST_NONTERMINAL * inv.arity as i32 + args.iter()
253 .map(|id| nodecost_of_treenode[id]
254 .cost_under_inv(&inv)) .sum::<i32>(); nodecost.new_cost_under_inv(inv.clone(), cost, Some(args));
257 }
258
259
260 match node {
262 Lambda::IVar(_) => { unreachable!() }
263 Lambda::Var(_) => {},
264 Lambda::Prim(_) => {},
265 Lambda::App([f,x]) => {
266 let f_nodecost = &nodecost_of_treenode[&f];
267 let x_nodecost = &nodecost_of_treenode[&x];
268
269 let fcost = f_nodecost.cost_under_inv(&inv);
272 let xcost = x_nodecost.cost_under_inv(&inv);
273 let cost = COST_NONTERMINAL+fcost+xcost;
274 nodecost.new_cost_under_inv(inv.clone(), cost, None);
275 }
276 Lambda::Lam([b]) => {
277 let b_nodecost = &nodecost_of_treenode[&b];
279 let bcost = b_nodecost.cost_under_inv(&inv);
280 nodecost.new_cost_under_inv(inv.clone(), bcost + COST_NONTERMINAL, None);
281 }
282 Lambda::Programs(roots) => {
283 let cost = roots.iter().map(|root| {
285 nodecost_of_treenode[root].cost_under_inv(&inv)
286 }).sum();
287 nodecost.new_cost_under_inv(inv.clone(), cost, None);
288 }
289 }
290
291 nodecost_of_treenode.insert(*treenode, nodecost);
292 }
293
294 extract_from_nodecosts(root, &inv, &nodecost_of_treenode, egraph)
296}
297
298fn extract_from_nodecosts(
299 root: Id,
300 inv: &PtrInvention,
301 nodecost_of_treenode: &FxHashMap<Id,NodeCost>,
302 egraph: &EGraph,
303) -> Expr {
304
305 let target_cost = nodecost_of_treenode[&root].cost_under_inv(inv);
306
307 if let Some((inv,_cost,args)) = nodecost_of_treenode[&root].top_invention() {
308 if let Some(args) = args {
309 let mut expr = Expr::prim(inv.name.clone().into());
311 for arg in args.iter() {
313 let arg_expr = extract_from_nodecosts(*arg, &inv, nodecost_of_treenode, egraph);
314 expr = Expr::app(expr,arg_expr);
315 }
316 assert_eq!(target_cost,expr.cost());
317 expr
318 } else {
319 let expr: Expr = match &egraph[root].nodes[0] {
321 Lambda::Prim(_) | Lambda::Var(_) | Lambda::IVar(_) => {unreachable!()},
322 Lambda::App([f,x]) => {
323 let f_expr = extract_from_nodecosts(*f, &inv, nodecost_of_treenode, egraph);
324 let x_expr = extract_from_nodecosts(*x, &inv, nodecost_of_treenode, egraph);
325 Expr::app(f_expr,x_expr)
326 },
327 Lambda::Lam([b]) => {
328 let b_expr = extract_from_nodecosts(*b, &inv, nodecost_of_treenode, egraph);
329 Expr::lam(b_expr)
330 }
331 Lambda::Programs(roots) => {
332 let root_exprs: Vec<Expr> = roots.iter()
333 .map(|r| extract_from_nodecosts(*r, &inv, nodecost_of_treenode, egraph))
334 .collect();
335 Expr::programs(root_exprs)
336 }
337 };
338 assert_eq!(target_cost,expr.cost());
339 expr
340 }
341 } else {
342 let expr = extract(root, egraph);
344 assert_eq!(target_cost,expr.cost());
345 expr
346 }
347}
348
349#[derive(Debug,Clone)]
352struct NodeCost {
353 inventionless_cost: i32,
354 inventionful_cost: FxHashMap<PtrInvention, (i32,Option<Vec<Id>>)>, }
356
357impl NodeCost {
358 fn new(inventionless_cost: i32) -> Self {
359 Self {
360 inventionless_cost,
361 inventionful_cost: FxHashMap::default()
362 }
363 }
364 fn cost_under_inv(&self, inv: &PtrInvention) -> i32 {
366 self.inventionful_cost.get(inv).map(|x|x.0).unwrap_or(self.inventionless_cost)
367 }
368 fn new_cost_under_inv(&mut self, inv: PtrInvention, cost:i32, args: Option<Vec<Id>>) {
371 if cost < self.inventionless_cost
372 && (!self.inventionful_cost.contains_key(&inv) || cost < self.inventionful_cost[&inv].0)
373 {
374 self.inventionful_cost.insert(inv, (cost,args));
375 }
376 }
377 #[allow(dead_code)] fn top_inventions(&self) -> Vec<PtrInvention> {
380 let mut top_inventions: Vec<PtrInvention> = self.inventionful_cost.keys().cloned().collect();
381 top_inventions.sort_by(|a,b| self.inventionful_cost[a].0.cmp(&self.inventionful_cost[b].0));
382 top_inventions
383 }
384 fn top_invention(&self) -> Option<(PtrInvention,i32,Option<Vec<Id>>)> {
386 self.inventionful_cost.iter().min_by_key(|(_k,v)| v.0).map(|(k,v)| (k.clone(),v.0,v.1.clone()))
387 }
388}
389
390
391fn match_expr_with_inv(
392 root: Id,
393 inv: &PtrInvention,
394 best_inventions_of_treenode: &mut FxHashMap<Id, NodeCost>,
395 egraph: &mut EGraph,
396) -> Option<Vec<Id>> {
397 let mut args: Vec<Option<Id>> = vec![None;inv.arity];
398 let threadables = threadables_of_inv(inv.clone(), egraph);
399 if match_expr_with_inv_rec(root, inv.body, 0, &mut args, &threadables, best_inventions_of_treenode, egraph) {
400 assert!(args.iter().all(|x| x.is_some()), "{:?}\n{}\n{}", args, extract(root,egraph), extract(inv.body,egraph)); Some(args.iter().map(|arg| arg.unwrap()).collect())
402 } else {
403 None
404 }
405}
406
/// Recursively unifies the tree node `root` with the invention-body node `inv`.
///
/// `depth` counts the lambdas matched so far (it is incremented when both sides are
/// `Lam`). `args[k]` holds the e-class id bound to ivar `k` so far. `threadables` comes
/// from `threadables_of_inv`.
///
/// Returns `true` when `root` is an instance of `inv` consistent with the existing
/// bindings. Along the way it may add new (shifted / lambda-wrapped) nodes to `egraph`
/// and register cost entries for them in `best_inventions_of_treenode`.
fn match_expr_with_inv_rec(
    root: Id,
    inv: Id,
    depth: i32,
    args: &mut [Option<Id>],
    threadables: &FxHashSet<Id>,
    best_inventions_of_treenode: &mut FxHashMap<Id, NodeCost>,
    egraph: &mut EGraph,
) -> bool {
    // Nodes are cloned so that `egraph` can be mutably borrowed (add/shift) inside the arms.
    match (&egraph[root].nodes[0].clone(), &egraph[inv].nodes[0].clone()) {
        (Lambda::Prim(p), Lambda::Prim(q)) => { p == q },
        (Lambda::Var(i), Lambda::Var(j)) => { i == j },
        // `inv` is an ivar applied to variables (see threadables_of_inv): the ivar may be
        // bound to `root` with those variables threaded through as lambdas.
        (root_node, Lambda::App([g,y])) if threadables.contains(&inv) => {
            // Free vars of `root` whose index is below `depth`, i.e. bound by lambdas
            // inside the matched region.
            let internal_free_vars: FxHashSet<i32> = egraph[root].data.free_vars.iter().filter(|i| **i < depth).cloned().collect();
            let num_to_thread = internal_free_vars.len() as i32;
            if internal_free_vars == egraph[inv].data.free_vars {
                // First try a plain structural match; on failure, restore `args`.
                if let Lambda::App([f,x]) = root_node {
                    let cloned_args: Vec<_> = args.to_vec();
                    if match_expr_with_inv_rec(*f, *g, depth, args, threadables, best_inventions_of_treenode, egraph)
                        && match_expr_with_inv_rec(*x, *y, depth, args, threadables, best_inventions_of_treenode, egraph) {
                        return true;
                    }
                    args.clone_from_slice(cloned_args.as_slice());
                }

                // Build the ivar's argument. For each of the `depth` enclosing levels,
                // wrap in a lambda if `inv` uses that variable, then shift free vars
                // down by one.
                let mut arg = root;
                for i in 0..depth {
                    if egraph[inv].data.free_vars.contains(&i) {
                        arg = egraph.add(Lambda::Lam([arg]));
                    }
                    arg = shift(arg, -1, egraph, &mut None).unwrap();
                }

                // Register a cost entry for the new node: `root`'s costs plus one
                // COST_NONTERMINAL per threaded var. Cached argument lists are dropped
                // because they refer to the unthreaded node.
                if !best_inventions_of_treenode.contains_key(&arg) {
                    let mut cloned = best_inventions_of_treenode[&root].clone();
                    cloned.inventionless_cost += COST_NONTERMINAL * num_to_thread;
                    cloned.inventionful_cost.iter_mut().for_each(|(_key, val)| {val.0 += COST_NONTERMINAL * num_to_thread; val.1 = None});
                    best_inventions_of_treenode.insert(arg,cloned);
                }

                // Takes the first free ivar of this subtree (assumes there is exactly one).
                let ivar = *egraph[inv].data.free_ivars.iter().next().unwrap() as usize;

                // Bind the ivar, or require consistency with an earlier binding.
                if let Some(v) = args[ivar] {
                    arg == v
                } else {
                    args[ivar] = Some(arg);
                    true
                }
            } else {
                // Free variables don't line up for threading; only a structural match can succeed.
                if let Lambda::App([f,x]) = root_node {
                    return match_expr_with_inv_rec(*f, *g, depth, args, threadables, best_inventions_of_treenode, egraph)
                        && match_expr_with_inv_rec(*x, *y, depth, args, threadables, best_inventions_of_treenode, egraph)
                }
                false
            }
        },
        (Lambda::App([f,x]), Lambda::App([g,y])) => {
            match_expr_with_inv_rec(*f, *g, depth, args, threadables, best_inventions_of_treenode, egraph)
                && match_expr_with_inv_rec(*x, *y, depth, args, threadables, best_inventions_of_treenode, egraph)
        }
        (Lambda::Lam([b]), Lambda::Lam([c])) => {
            match_expr_with_inv_rec(*b, *c, depth+1, args, threadables, best_inventions_of_treenode, egraph)
        },
        // Ivar: bind it to `root` shifted down by `depth`, so the argument is valid
        // outside the matched lambdas.
        (_, Lambda::IVar(j)) => {
            let shifted_root: Id = if egraph[root].data.free_vars.is_empty() {
                // Closed term: nothing to shift.
                root
            } else if egraph[root].data.free_vars.iter().min().unwrap() - depth >= 0 {
                /// Shifts `node` down by `depth` and registers a cost entry for the
                /// shifted id, cloned from `node`'s. Stored invention argument ids are
                /// shifted recursively the same way.
                fn shift_and_fix(node: Id, depth: i32, best_inventions_of_treenode: &mut FxHashMap<Id,NodeCost>, egraph: &mut EGraph) -> Id {
                    let shifted_node = shift(node, -depth, egraph, &mut None).unwrap();
                    if best_inventions_of_treenode.contains_key(&shifted_node) {
                        return shifted_node;
                    }
                    let mut cloned = best_inventions_of_treenode[&node].clone();
                    cloned.inventionful_cost.iter_mut().for_each(|(_key, val)| {
                        if let Some(args) = &mut val.1 {
                            args.iter_mut().for_each(|arg| *arg = shift_and_fix(*arg, depth, best_inventions_of_treenode, egraph));
                        }
                    });
                    best_inventions_of_treenode.insert(shifted_node,cloned);
                    shifted_node
                }
                shift_and_fix(root, depth, best_inventions_of_treenode, egraph)
            } else {
                // `root` refers to a variable bound inside the match, so it cannot be an argument.
                return false
            };

            // Bind the ivar, or require consistency with an earlier binding.
            if let Some(v) = args[*j as usize] {
                shifted_root == v
            } else {
                args[*j as usize] = Some(shifted_root);
                true
            }
        },
        _ => { false }
    }
}
541
542fn threadables_of_inv(inv: PtrInvention, egraph: &EGraph) -> FxHashSet<Id> {
543 let mut threadables: FxHashSet<Id> = Default::default();
547 let nodes = topological_ordering(inv.body, egraph);
548 for node in nodes {
549 if let Lambda::App([f,x]) = egraph[node].nodes[0] {
550 if matches!(egraph[x].nodes[0], Lambda::Var(_))
551 && (matches!(egraph[f].nodes[0], Lambda::IVar(_)) || threadables.contains(&f))
552 {
553 threadables.insert(node);
554 }
556 }
557 }
558 threadables
559}