// stitch_core/bottom_up_synthesis/mod.rs
use std::time::Instant;

use clap::Parser;
use itertools::Itertools;
use rustc_hash::FxHashMap;
use serde::Serialize;

use crate::*;
8#[derive(Parser, Debug, Serialize)]
13#[clap(name = "Bottom-up synthesis")]
14pub struct BottomUpConfig {
15 #[clap(long, default_value = "1")]
17 pub cost_step: usize,
18
19 #[clap(short = 'c', long, default_value = "10")]
21 pub max_cost: usize,
22
23 #[clap(long)]
25 pub print_found: bool,
26}
27
28#[derive(Clone)]
29pub struct Found<D: Domain> {
30 val: Val<D>, id: Id, cost: usize, }
34
35#[derive(Clone)]
36pub struct FoundExpr<D: Domain> {
37 val: Val<D>, expr: Expr, cost: usize, }
41
42impl <D: Domain> Found<D> {
43 fn new(val: Val<D>, id: Id, cost: usize) -> Self {
44 Found {
45 val,
46 id,
47 cost,
48 }
49 }
50}
51
52impl <D: Domain> FoundExpr<D> {
53 pub fn new(val: Val<D>, expr: Expr, cost: usize) -> Self {
54 FoundExpr {
55 val,
56 expr,
57 cost,
58 }
59 }
60}
61
/// Counters gathered during a run of [`bottom_up`] and printed with `{:?}` at the end.
#[derive(Clone, Debug, Default)]
struct Stats {
    /// Function applications that evaluated successfully.
    num_eval_ok: usize,
    /// Function applications whose evaluation returned an error.
    num_eval_err: usize,
    /// Successful results whose value had never been recorded.
    num_not_seen: usize,
    /// Successful results whose value had already been recorded.
    num_yes_seen: usize,
    /// Of the already-recorded results, those whose recorded cost was at least as cheap.
    num_yes_seen_and_was_better: usize,
}

71pub fn bottom_up<D: Domain>(
72 initial: &[FoundExpr<D>],
74 fns: &[(DSLEntry<D>,usize)],
75 cfg: &BottomUpConfig,
76) {
77
78 let fns: Vec<(DSLEntry<D>,usize)> = fns.iter().filter(|(entry, _)| entry.arity > 0).cloned().collect();
79
80 let tstart = Instant::now();
81 let mut stats: Stats = Default::default();
82
83 let mut curr_cost = cfg.cost_step;
84 let mut vals_of_type: FxHashMap<Type,Vec<Found<D>>> = Default::default();
85
86 let mut handle: Expr = {
88 let dsl_fns_expr: Expr = Expr::programs(fns.iter().map(|(entry,_)| Expr::prim(entry.name)).collect());
89 let init_vals_expr: Expr = Expr::programs(initial.iter().map(|found_expr| found_expr.expr.clone()).collect());
90 Expr::programs(vec![dsl_fns_expr,init_vals_expr])
91 };
92
93 let mut seen: FxHashMap<Val<D>,usize> = FxHashMap::default();
94
95 let init_val_ids: Vec<Id> = handle.get(handle.get_root().children()[1]).children().to_vec();
97
98 println!("Productions:");
99 for (f,cost) in fns.iter() {
100 println!("(cost {}) {} :: {}", cost, f.name, f.tp);
101 }
102
103 println!("Initial:");
104 for (i,found_expr) in initial.iter().enumerate() {
105 let id = init_val_ids[i]; let found = Found::new(found_expr.val.clone(), id, found_expr.cost);
107 let tp = found_expr.expr.infer::<D>(None, &mut Context::empty(), &mut Default::default()).unwrap();
108
109 println!("(cost {}) {} :: {} => {:?}", found.cost, handle.to_string_uncurried(Some(found.id)), tp, found.val);
110
111 vals_of_type.entry(tp).or_default().push(found.clone());
112 seen.insert(found.val.clone(), found.cost);
113 }
114
115
116
117 while curr_cost < cfg.max_cost {
118 vals_of_type.values_mut().for_each(|vals| {
120 vals.sort_by(|a,b| a.cost.cmp(&b.cost));
121 });
123
124 let seen_types: Vec<Type> = vals_of_type.keys().cloned().collect();
125
126 println!("new curr cost: {}", curr_cost);
127 let mut new_vals_of_type: FxHashMap<Type,Vec<Found<D>>> = Default::default();
128
129
130 for (i_fn, (dsl_entry, fn_cost)) in fns.iter().enumerate() {
131 for (found_args, tp, cost) in ArgChoiceIterator::new(&vals_of_type, &seen_types, &dsl_entry.tp, *fn_cost, curr_cost, curr_cost - cfg.cost_step) {
134 let args: Vec<LazyVal<D>> = found_args.iter().map(|&f| LazyVal::new_strict(f.val.clone())).collect();
135 if let Ok(val) = (D::lookup_fn_ptr(dsl_entry.name)) (args, &mut handle.as_eval(None)) {
137 stats.num_eval_ok += 1;
138 match seen.get(&val) {
139 None => {
140 stats.num_not_seen += 1;
141 let mut id = Id::from(i_fn); for arg in found_args.iter() {
143 handle.nodes.push(Lambda::App([id,arg.id]));
144 id = Id::from(handle.nodes.len()-1);
145 }
146 new_vals_of_type.entry(tp).or_default().push(Found::new(val, id, cost));
147 }
148 Some(&old_cost) => {
149 stats.num_yes_seen += 1;
150 if old_cost > cost {
151 let mut id = Id::from(i_fn); for arg in found_args.iter() {
153 handle.nodes.push(Lambda::App([id,arg.id]));
154 id = Id::from(handle.nodes.len()-1);
155 }
156 new_vals_of_type.entry(tp).or_default().push(Found::new(val, id, cost));
157
158 } else {
159 stats.num_yes_seen_and_was_better += 1;
160 }
161 }
162 }
163
164 } else {
165 stats.num_eval_err += 1;
167 }
168 }
169 }
170
171 for (tp, new_vals) in new_vals_of_type.into_iter() {
173 for found in new_vals.into_iter() {
174 match seen.get(&found.val) {
175 None => {
176 seen.insert(found.val.clone(),found.cost);
177 vals_of_type.entry(tp.clone()).or_default().push(found.clone());
178 if cfg.print_found{
179 println!("(cost {}) {} :: {} => {:?}", found.cost, handle.to_string_uncurried(Some(found.id)), tp, found.val);
180 }
181 }
182 Some(&old_cost) => {
183 if old_cost > found.cost {
184 *seen.get_mut(&found.val).unwrap() = found.cost;
185 vals_of_type.get_mut(&tp).unwrap().retain(|f| f.val != found.val);
190
191 vals_of_type.entry(tp.clone()).or_default().push(found.clone());
193 if cfg.print_found{
194 println!("(cost {}) {} :: {:?} -> {:?}", found.cost, handle.to_string_uncurried(Some(found.id)), tp, found.val);
195 }
196 }
197 }
198 }
199 }
200 }
201
202
203
204 curr_cost += cfg.cost_step;
205 }
206
207 println!("reached max cost");
209 println!("Time: {}ms",tstart.elapsed().as_millis());
210 println!("{:?}",stats);
211 println!("num found: {}",seen.len());
212 println!("num found per ms: {:.2}", seen.len() as f64 / tstart.elapsed().as_millis() as f64);
213 println!("num eval total: {}",stats.num_eval_ok+stats.num_eval_err);
214 println!("% eval ok: {:.2}%", stats.num_eval_ok as f64 / (stats.num_eval_ok + stats.num_eval_err) as f64 * 100.0);
215 println!("num eval per ms: {:.2}",(stats.num_eval_ok+stats.num_eval_err) as f64 / tstart.elapsed().as_millis() as f64);
216 println!("num found by type:\n\t{}", vals_of_type.iter().map(|(ty,vals)| format!("{}: {}", ty, vals.len())).collect::<Vec<_>>().join("\n\t"));
217
218 }
245
246
247struct ArgChoiceIterator<'a, D: Domain> {
248 args: Vec<ArgState<'a,D>>,
249 arg_tp_iter: Box<dyn Iterator<Item=(Vec<(&'a Type, &'a Type)>, Type)> + 'a>,
250 vals_of_type: &'a FxHashMap<Type,Vec<Found<D>>>, return_tp: Option<Type>,
252 fn_cost: usize,
253 max_cost: usize,
254 prev_max_cost: usize,
255 prev_idx_to_inc: usize,
256}
257
258struct ArgState<'a, D: Domain> {
259 i_vals: usize,
260 tp: &'a Type,
261 vals: &'a [Found<D>]
262}
263
264impl <'a, D: Domain> ArgChoiceIterator<'a,D> {
279 fn new(vals_of_type: &'a FxHashMap<Type,Vec<Found<D>>>, seen_types: &'a [Type], fn_tp: &'a Type, fn_cost: usize, max_cost: usize, prev_max_cost: usize) -> Self {
280 assert!( max_cost > prev_max_cost);
281 assert!(fn_tp.arity() > 0); let mut arg_tp_iter = fn_tp.iter_args().map(|arg_tp|
285 seen_types.iter()
286 .filter(move |seen_tp| Context::empty().unify(seen_tp, arg_tp).is_ok()) .map(move |seen_tp| (seen_tp,arg_tp))
288 ).multi_cartesian_product()
289 .filter_map(move |seen_arg_tps|{
290 let mut ctx = Context::empty();
292 if !seen_arg_tps.iter().all(|(seen_tp, arg_tp)| {
293 let ty = arg_tp.apply(&mut ctx);
294 ctx.unify(seen_tp, &ty).is_ok()
295 }) {
296 None } else {
298 Some((seen_arg_tps, fn_tp.return_type().apply(&mut ctx)))
300 }
301 });
302
303 let (args, return_tp) = arg_tp_iter.next().map(|(seen_arg_tps, return_tp)| (seen_arg_tps.iter().map(|(seen_tp,_)| ArgState { i_vals: 0, tp: seen_tp, vals: &vals_of_type[seen_tp]}).collect(), Some(return_tp))).unwrap_or((vec![],None));
305
306 ArgChoiceIterator {
307 args,
308 arg_tp_iter: Box::new(arg_tp_iter),
309 vals_of_type,
310 return_tp,
311 fn_cost,
312 max_cost,
313 prev_max_cost,
314 prev_idx_to_inc: 0,
315 }
316 }
317 fn next_tps(&mut self) -> bool {
318 match self.arg_tp_iter.next() {
319 Some((seen_arg_tps, return_tp)) => {
320 for (arg, (seen_tp,_)) in self.args.iter_mut().zip(seen_arg_tps.iter()) {
321 arg.i_vals = 0;
322 arg.tp = seen_tp;
323 arg.vals = &self.vals_of_type[seen_tp];
324 }
325 self.return_tp = Some(return_tp);
326 true
327 },
328 None => {
329 self.return_tp = None;
330 false
331 },
332 }
333 }
334 fn rollover(&mut self) {
335 let mut carry = false;
336 for (i,arg) in self.args.iter_mut().enumerate() {
337 if carry {
338 arg.i_vals += 1; self.prev_idx_to_inc = i;
340 carry = false;
341 }
342 if arg.i_vals >= arg.vals.len() {
343 arg.i_vals = 0;
344 carry = true;
345 }
346 }
347 if carry {
348 self.args.last_mut().unwrap().i_vals = self.args.last().unwrap().vals.len();
349 }
350 }
351}
352
353
354impl<'a, D: Domain> Iterator for ArgChoiceIterator<'a, D> {
355 type Item = (Vec<&'a Found<D>>, Type, usize);
356
357 fn next(&mut self) -> Option<Self::Item> {
358 if self.return_tp == None {
359 return None }
361
362 loop {
363 if self.args.last().unwrap().i_vals >= self.args.last().unwrap().vals.len() {
365 if self.next_tps() {
366 continue
367 } else {
368 return None;
369 }
370 }
371
372 let cost: usize = self.fn_cost + self.args.iter().map(|arg| arg.vals[arg.i_vals].cost).sum::<usize>();
376
377 if cost > self.max_cost {
378 self.args[self.prev_idx_to_inc].i_vals = self.args[self.prev_idx_to_inc].vals.len();
380 debug_assert!(self.args[..self.prev_idx_to_inc].iter().all(|arg| arg.i_vals == 0));
381 self.rollover();
382 continue;
383 }
384
385 if cost <= self.prev_max_cost {
387 self.args.first_mut().unwrap().i_vals += 1;
388 self.prev_idx_to_inc = 0;
389 self.rollover();
390 continue;
391 }
392
393 let res: Vec<&Found<D>> = self.args.iter().map(|arg| &arg.vals[arg.i_vals]).collect();
394
395 self.args.first_mut().unwrap().i_vals += 1;
397 self.prev_idx_to_inc = 0;
398 self.rollover();
399
400
401 return Some((res, self.return_tp.clone().unwrap(), cost))
402 }
403 }
404}