stitch_core/bottom_up_synthesis/
mod.rs

1
2use crate::*;
3use rustc_hash::{FxHashMap};
4use std::{time::Instant};
5use clap::{Parser};
6use serde::Serialize;
7use itertools::Itertools;
8// use serde_json::json;
9
10
11/// Bottom-up synthesis
12#[derive(Parser, Debug, Serialize)]
13#[clap(name = "Bottom-up synthesis")]
14pub struct BottomUpConfig {
15    /// How big of a step to increase cost by for each round of bottom up
16    #[clap(long, default_value = "1")]
17    pub cost_step: usize,
18
19    /// Max cost to enumerate to
20    #[clap(short = 'c', long, default_value = "10")]
21    pub max_cost: usize,
22
23    /// print all exprs found at end
24    #[clap(long)]
25    pub print_found: bool,
26}
27
28#[derive(Clone)]
29pub struct Found<D: Domain> {
30    val: Val<D>, // value that was found
31    id: Id, // expr that constructs it
32    cost: usize, // cost of constructing it
33}
34
35#[derive(Clone)]
36pub struct FoundExpr<D: Domain> {
37    val: Val<D>, // value that was found
38    expr: Expr, // expr that constructs it
39    cost: usize, // cost of constructing it
40}
41
42impl <D: Domain> Found<D> {
43    fn new(val: Val<D>, id: Id, cost: usize) -> Self {
44        Found {
45            val,
46            id,
47            cost,
48        }
49    }
50}
51
52impl <D: Domain> FoundExpr<D> {
53    pub fn new(val: Val<D>, expr: Expr, cost: usize) -> Self {
54        FoundExpr {
55            val,
56            expr,
57            cost,
58        }
59    }
60}
61
/// Counters accumulated by [`bottom_up`] and printed (via `Debug`) in its final summary.
#[derive(Default, Debug, Clone)]
struct Stats {
    /// Production applications whose evaluation returned `Ok`.
    num_eval_ok: usize,
    /// Production applications whose evaluation returned `Err` (result discarded).
    num_eval_err: usize,
    /// Successful evaluations that produced a value not yet recorded in `seen`.
    num_not_seen: usize,
    /// Successful evaluations that produced a value already recorded in `seen`.
    num_yes_seen: usize,
    /// Subset of `num_yes_seen` where the recorded cost was no higher, so the new expr was dropped.
    num_yes_seen_and_was_better: usize,
}
70
71pub fn bottom_up<D: Domain>(
72    // handle: &mut Evaluator<D>,
73    initial: &[FoundExpr<D>],
74    fns: &[(DSLEntry<D>,usize)],
75    cfg: &BottomUpConfig,
76) {
77
78    let fns: Vec<(DSLEntry<D>,usize)>  = fns.iter().filter(|(entry, _)| entry.arity > 0).cloned().collect();
79
80    let tstart = Instant::now();
81    let mut stats: Stats = Default::default();
82
83    let mut curr_cost = cfg.cost_step;
84    let mut vals_of_type: FxHashMap<Type,Vec<Found<D>>> = Default::default();
85
86    // add each dsl fn so the ith dsl fn is expr.nodes[i] // todo programs() is a gross way to do this
87    let mut handle: Expr = {
88        let dsl_fns_expr: Expr = Expr::programs(fns.iter().map(|(entry,_)| Expr::prim(entry.name)).collect());
89        let init_vals_expr: Expr = Expr::programs(initial.iter().map(|found_expr| found_expr.expr.clone()).collect());    
90        Expr::programs(vec![dsl_fns_expr,init_vals_expr])
91    };
92
93    let mut seen: FxHashMap<Val<D>,usize> = FxHashMap::default();
94
95    // ids for the exprs passed in as `initial`
96    let init_val_ids: Vec<Id> = handle.get(handle.get_root().children()[1]).children().to_vec();
97
98    println!("Productions:");
99    for (f,cost) in fns.iter() {
100        println!("(cost {}) {} :: {}", cost, f.name, f.tp);
101    }
102
103    println!("Initial:");
104    for (i,found_expr) in initial.iter().enumerate() {
105        let id = init_val_ids[i]; // ith child of programs node
106        let found = Found::new(found_expr.val.clone(), id, found_expr.cost);
107        let tp = found_expr.expr.infer::<D>(None, &mut Context::empty(), &mut Default::default()).unwrap();
108
109        println!("(cost {}) {} :: {} => {:?}", found.cost, handle.to_string_uncurried(Some(found.id)), tp, found.val);
110
111        vals_of_type.entry(tp).or_default().push(found.clone());
112        seen.insert(found.val.clone(), found.cost);
113    }
114
115
116
117    while curr_cost < cfg.max_cost {
118        // sort by the cost
119        vals_of_type.values_mut().for_each(|vals| {
120            vals.sort_by(|a,b| a.cost.cmp(&b.cost));
121            // vals.dedup_by(|a,b| a.val == b.val);
122        });
123
124        let seen_types: Vec<Type> = vals_of_type.keys().cloned().collect();
125
126        println!("new curr cost: {}", curr_cost);
127        let mut new_vals_of_type: FxHashMap<Type,Vec<Found<D>>> = Default::default();
128
129
130        for (i_fn, (dsl_entry, fn_cost)) in fns.iter().enumerate() {
131            // println!("trying fn: {}", dsl_entry.name);
132
133            for (found_args, tp, cost) in ArgChoiceIterator::new(&vals_of_type, &seen_types, &dsl_entry.tp, *fn_cost, curr_cost, curr_cost - cfg.cost_step) {
134                let args: Vec<LazyVal<D>> = found_args.iter().map(|&f| LazyVal::new_strict(f.val.clone())).collect();
135                // println!("trying ({} {})", dsl_entry.name, found_cfg.iter().map(|arg| format!("{:?}",arg.val)).collect::<Vec<_>>().join(" "));
136                if let Ok(val) = (D::lookup_fn_ptr(dsl_entry.name)) (args, &mut handle.as_eval(None)) {
137                    stats.num_eval_ok += 1;
138                    match seen.get(&val) {
139                        None => {
140                            stats.num_not_seen += 1;
141                            let mut id = Id::from(i_fn); // assumes we constructed the ith fn primitive to be the ith element in handle.expr.nodes
142                            for arg in found_args.iter() {
143                                handle.nodes.push(Lambda::App([id,arg.id]));
144                                id = Id::from(handle.nodes.len()-1);
145                            }
146                            new_vals_of_type.entry(tp).or_default().push(Found::new(val, id, cost));
147                        }
148                        Some(&old_cost) => {
149                            stats.num_yes_seen += 1;
150                            if old_cost > cost {
151                                let mut id = Id::from(i_fn); // assumes we constructed the ith fn primitive to be the ith element in handle.expr.nodes
152                                for arg in found_args.iter() {
153                                    handle.nodes.push(Lambda::App([id,arg.id]));
154                                    id = Id::from(handle.nodes.len()-1);
155                                }
156                                new_vals_of_type.entry(tp).or_default().push(Found::new(val, id, cost));
157        
158                            } else {
159                                stats.num_yes_seen_and_was_better += 1;
160                            }
161                        }
162                    }
163
164                } else {
165                    // Err from execution, discard
166                    stats.num_eval_err += 1;
167                }
168            }
169        }
170
171        // deposit new vals into vals_of_type
172        for (tp, new_vals) in new_vals_of_type.into_iter() {
173            for found in new_vals.into_iter() {
174                match seen.get(&found.val) {
175                    None => {
176                        seen.insert(found.val.clone(),found.cost);
177                        vals_of_type.entry(tp.clone()).or_default().push(found.clone());
178                        if cfg.print_found{
179                            println!("(cost {}) {} :: {} => {:?}", found.cost, handle.to_string_uncurried(Some(found.id)), tp, found.val);
180                        }
181                    }
182                    Some(&old_cost) => {
183                        if old_cost > found.cost {
184                            *seen.get_mut(&found.val).unwrap() = found.cost;
185                            // removes old value
186                            // todo this is prob v slow as implemented, could do faster with a binary search by cost or something which I guess works since we do assume vals_of_type is sorted by cost
187                            // HOWEVER that invariant gets broken during this process so we should actually switch to doing a binary insertion if we do this.
188                            // vals_of_type.get_mut(tp).unwrap().partition_point(|found| found.cost)
189                            vals_of_type.get_mut(&tp).unwrap().retain(|f| f.val != found.val);
190
191                            // add new value
192                            vals_of_type.entry(tp.clone()).or_default().push(found.clone());
193                            if cfg.print_found{
194                                println!("(cost {}) {} :: {:?} -> {:?}", found.cost, handle.to_string_uncurried(Some(found.id)), tp, found.val);
195                            }  
196                        }
197                    }
198                }
199            }
200        }
201
202
203
204        curr_cost += cfg.cost_step;
205    }
206
207    //todo add a sanity check that the length of seen equals the lengths of all val arrays. i bet theres an error and that wont be true lol
208    println!("reached max cost");
209    println!("Time: {}ms",tstart.elapsed().as_millis());
210    println!("{:?}",stats);
211    println!("num found: {}",seen.len());
212    println!("num found per ms: {:.2}", seen.len() as f64 / tstart.elapsed().as_millis() as f64);
213    println!("num eval total: {}",stats.num_eval_ok+stats.num_eval_err);
214    println!("% eval ok: {:.2}%", stats.num_eval_ok as f64 / (stats.num_eval_ok + stats.num_eval_err) as f64 * 100.0);
215    println!("num eval per ms: {:.2}",(stats.num_eval_ok+stats.num_eval_err) as f64 / tstart.elapsed().as_millis() as f64);
216    println!("num found by type:\n\t{}", vals_of_type.iter().map(|(ty,vals)| format!("{}: {}", ty, vals.len())).collect::<Vec<_>>().join("\n\t"));
217
218    // write a json out with everything that was found
219    // let out = json!({
220    //     "stats": {
221    //         "num_eval_ok": num_eval_ok,
222    //         "num_eval_err": num_eval_err,
223    //         "num_eval_total": num_eval_ok+num_eval_err,
224    //         "percent_eval_ok": num_eval_ok as f64 / (num_eval_ok + num_eval_err) as f64 * 100.0,
225    //         "num_eval_per_ms": (num_eval_ok+num_eval_err) as f64 / tstart.elapsed().as_millis() as f64,
226    //         "num_not_seen": num_not_seen,
227    //         "num_yes_seen": num_yes_seen,
228    //         "num_yes_seen_and_was_better": num_yes_seen_and_was_better,
229    //     },
230    // });
231
232
233    // let out_path = cfg.out;
234    // if let Some(out_path_dir) = out_path.parent() {
235    //     if !out_path_dir.exists() {
236    //         std::fs::create_dir_all(out_path_dir).unwrap();
237    //     }
238    // }
239    // std::fs::write(out_path, serde_json::to_string_pretty(&out).unwrap()).unwrap();
240    // println!("Wrote to {:?}", out_path);
241
242
243
244}
245
246
247struct ArgChoiceIterator<'a, D: Domain> {
248    args: Vec<ArgState<'a,D>>,
249    arg_tp_iter: Box<dyn Iterator<Item=(Vec<(&'a Type, &'a Type)>, Type)> + 'a>,
250    vals_of_type: &'a FxHashMap<Type,Vec<Found<D>>>, // vals[i] is the list of found vals for the ith arg
251    return_tp: Option<Type>,
252    fn_cost: usize,
253    max_cost: usize,
254    prev_max_cost: usize,
255    prev_idx_to_inc: usize,
256}
257
258struct ArgState<'a, D: Domain> {
259    i_vals: usize,
260    tp: &'a Type,
261    vals: &'a [Found<D>]
262}
263
264// struct ArgTypeIter<'a> {
265//     fn_tp: &'a Type,
266//     seen_types: &'a [Type],
267// }
268// impl<'a> Iterator for ArgTypeIter<'a> {
269//     type Item = (Vec<(&'a Type, &'a Type)>, Type);
270
271//     fn next(&mut self) -> Option<Self::Item> {
272
273//     }
274// }
275
276
277
278impl <'a, D: Domain> ArgChoiceIterator<'a,D> {
279    fn new(vals_of_type: &'a FxHashMap<Type,Vec<Found<D>>>, seen_types: &'a [Type], fn_tp: &'a Type, fn_cost: usize, max_cost:  usize, prev_max_cost: usize) -> Self {
280        assert!( max_cost > prev_max_cost);
281        assert!(fn_tp.arity() > 0); // we use the empty .args list as a sentinel
282        
283
284        let mut arg_tp_iter = fn_tp.iter_args().map(|arg_tp|
285            seen_types.iter()
286                      .filter(move |seen_tp| Context::empty().unify(seen_tp, arg_tp).is_ok()) // filter for ones that unify with the expected type
287                      .map(move |seen_tp| (seen_tp,arg_tp))
288            ).multi_cartesian_product()
289             .filter_map(move |seen_arg_tps|{
290                // unify all the args together in one context to see if they're all mutually compatible
291                let mut ctx = Context::empty();
292                if !seen_arg_tps.iter().all(|(seen_tp, arg_tp)| {
293                    let ty = arg_tp.apply(&mut ctx);
294                    ctx.unify(seen_tp, &ty).is_ok()
295                }) {
296                    None // at least one unify() failure
297                } else {
298                    // Some(seen_arg_tps)
299                    Some((seen_arg_tps, fn_tp.return_type().apply(&mut ctx)))
300                }
301             });
302
303        // initialize the `args` field or make it an empty vector as a sentinel in case theres no first item in `arg_tp_iter`
304        let (args, return_tp) = arg_tp_iter.next().map(|(seen_arg_tps, return_tp)| (seen_arg_tps.iter().map(|(seen_tp,_)| ArgState { i_vals: 0, tp: seen_tp, vals: &vals_of_type[seen_tp]}).collect(), Some(return_tp))).unwrap_or((vec![],None));
305
306        ArgChoiceIterator {
307            args,
308            arg_tp_iter: Box::new(arg_tp_iter),
309            vals_of_type,
310            return_tp,
311            fn_cost,
312            max_cost,
313            prev_max_cost,
314            prev_idx_to_inc: 0,
315        }
316    }
317    fn next_tps(&mut self) -> bool {
318        match self.arg_tp_iter.next() {
319            Some((seen_arg_tps, return_tp)) => {
320                for (arg, (seen_tp,_)) in self.args.iter_mut().zip(seen_arg_tps.iter()) {
321                    arg.i_vals = 0;
322                    arg.tp = seen_tp;
323                    arg.vals = &self.vals_of_type[seen_tp];
324                }
325                self.return_tp = Some(return_tp);
326                true
327            },
328            None => {
329                self.return_tp = None;
330                false
331            },
332        }
333    }
334    fn rollover(&mut self) {
335        let mut carry = false;
336        for (i,arg) in self.args.iter_mut().enumerate() {
337            if carry {
338                arg.i_vals += 1; // carry the +1
339                self.prev_idx_to_inc = i;
340                carry = false;
341            }
342            if arg.i_vals >= arg.vals.len() {
343                arg.i_vals = 0;
344                carry = true;
345            }
346        }
347        if carry {
348            self.args.last_mut().unwrap().i_vals = self.args.last().unwrap().vals.len();
349        }
350    }
351}
352
353
354impl<'a, D: Domain> Iterator for ArgChoiceIterator<'a, D> {
355    type Item = (Vec<&'a Found<D>>, Type, usize);
356
357    fn next(&mut self) -> Option<Self::Item> {
358        if self.return_tp == None {
359            return None // this is a sentinel
360        }
361
362        loop {
363            // termination condition / set new types
364            if self.args.last().unwrap().i_vals >= self.args.last().unwrap().vals.len() {
365                if self.next_tps() {
366                    continue
367                } else {
368                    return None;
369                }
370            }
371
372            // check the cost, and if its too high then max out whatever the last thing
373            // to increment was (which we know has all zeros to the left of it bc thats
374            // how incrementing happens) so that the thing one higher than it will get incremented
375            let cost: usize = self.fn_cost + self.args.iter().map(|arg| arg.vals[arg.i_vals].cost).sum::<usize>();
376            
377            if cost > self.max_cost {
378                // skip ahead off the end bc we know theyll all be too expensive.
379                self.args[self.prev_idx_to_inc].i_vals = self.args[self.prev_idx_to_inc].vals.len();
380                debug_assert!(self.args[..self.prev_idx_to_inc].iter().all(|arg| arg.i_vals == 0));
381                self.rollover();
382                continue;
383            }
384
385            // check if cost is somethign we could have caught on a previous iteration
386            if cost <= self.prev_max_cost {
387                self.args.first_mut().unwrap().i_vals += 1;
388                self.prev_idx_to_inc = 0;
389                self.rollover();
390                continue;
391            }
392
393            let res: Vec<&Found<D>> = self.args.iter().map(|arg| &arg.vals[arg.i_vals]).collect();
394
395            // just increment the base
396            self.args.first_mut().unwrap().i_vals += 1;
397            self.prev_idx_to_inc = 0;
398            self.rollover();
399
400
401            return Some((res, self.return_tp.clone().unwrap(), cost))
402        }
403    }
404}