Skip to main content

prune_lang/interp/
runner.rs

1use super::config::{RunnerConfig, RunnerStats};
2use super::solver;
3use super::strategy::*;
4use super::*;
5use crate::cli::args;
6use crate::cli::pipeline::PipeIO;
7
8pub struct RunnerState<'prog, 'io> {
9    prog: &'prog Program,
10    pipe_io: &'io mut PipeIO,
11    config: RunnerConfig,
12    stats: RunnerStats,
13    ctx_cnt: usize,
14    ansr_cnt: usize,
15    stack: Vec<Branch>,
16    solver: Box<dyn solver::common::PrimSolver>,
17}
18
19impl<'prog, 'io> RunnerState<'prog, 'io> {
20    pub fn new(
21        prog: &'prog Program,
22        pipe: &'io mut PipeIO,
23        solver: args::Solver,
24        heuristic: args::Heuristic,
25        debug_mode: bool,
26    ) -> RunnerState<'prog, 'io> {
27        let solver_obj: Box<dyn solver::common::PrimSolver> = match solver {
28            args::Solver::Z3 => Box::new(super::solver::smtlib::SmtLibSolver::new(
29                super::solver::smtlib::SolverBackend::Z3,
30            )),
31            args::Solver::CVC5 => Box::new(super::solver::smtlib::SmtLibSolver::new(
32                super::solver::smtlib::SolverBackend::CVC5,
33            )),
34            args::Solver::NoSmt => Box::new(super::solver::no_smt::NoSmtSolver::new()),
35        };
36
37        RunnerState {
38            prog,
39            pipe_io: pipe,
40            config: RunnerConfig::new(solver, heuristic, debug_mode),
41            stats: RunnerStats::new(),
42            ctx_cnt: 0,
43            ansr_cnt: 0,
44            stack: Vec::new(),
45            solver: solver_obj,
46        }
47    }
48
49    pub fn config_set_param(&mut self, param: &QueryParam) {
50        self.config.set_param(param);
51    }
52
53    fn reset(&mut self) {
54        self.stats.reset();
55        assert!(self.stack.is_empty());
56        self.ctx_cnt = 0;
57    }
58
59    pub fn run_dfs_with_depth(&mut self, pred: Ident, depth_start: usize, depth_end: usize) {
60        // predicate for query can not be polymorphic!
61        assert!(self.prog.preds[&pred].polys.is_empty());
62
63        self.ctx_cnt = 0;
64        let pars: Vec<Ident> = self.prog.preds[&pred]
65            .pars
66            .iter()
67            .map(|(par, _typ)| *par)
68            .collect();
69
70        let rules = &self.prog.preds[&pred].rules;
71        let mut call = PredCall {
72            pred,
73            polys: Vec::new(),
74            args: pars.iter().map(|par| Term::Var(par.tag_ctx(0))).collect(),
75            looks: (0..rules.len()).collect(),
76            history: History::new(),
77        };
78
79        if self.config.heuristic == args::Heuristic::LookAhead {
80            self.stats.step_la();
81            call.lookahead_update(rules);
82        }
83
84        let brch = Branch {
85            depth: 0,
86            answers: pars
87                .iter()
88                .map(|par| (*par, Term::Var(par.tag_ctx(0))))
89                .collect(),
90            prims: Vec::new(),
91            calls: vec![call],
92        };
93
94        self.stack.push(brch);
95
96        while let Some(mut brch) = self.stack.pop() {
97            if self.config.debug_mode {
98                println!("{}", brch);
99
100                // pause to wait for any input
101                let mut s = String::new();
102                std::io::stdin().read_line(&mut s).unwrap();
103            }
104
105            if self.ansr_cnt >= self.config.answer_limit {
106                return;
107            }
108            assert!(brch.depth <= depth_end);
109            if brch.calls.is_empty() {
110                if brch.depth >= depth_start {
111                    self.solve_answer(&brch);
112                }
113            } else if brch.depth + brch.calls.len() <= depth_end {
114                self.run_branch_step(&mut brch);
115            }
116        }
117    }
118
119    fn solve_answer(&mut self, brch: &Branch) {
120        let start = std::time::Instant::now();
121
122        if let Some(map) = self.solver.check_sat(&brch.prims) {
123            let duration = start.elapsed();
124            writeln!(
125                self.pipe_io.output,
126                "[ANSWER]: depth = {}, solving time = {:?}",
127                brch.depth, duration
128            )
129            .unwrap();
130
131            let map = map
132                .into_iter()
133                .map(|(var, lit)| (var, Term::Lit(lit)))
134                .collect();
135
136            for (par, val) in brch.answers.iter() {
137                writeln!(self.pipe_io.output, "{} = {}", par, val.substitute(&map)).unwrap();
138            }
139            self.ansr_cnt += 1;
140        }
141    }
142
143    fn run_branch_step(&mut self, brch: &mut Branch) {
144        let call_idx = match self.config.heuristic {
145            args::Heuristic::LeftBiased => brch.left_biased_strategy(),
146            args::Heuristic::Interleave => brch.naive_strategy(1),
147            args::Heuristic::StructRecur => brch.struct_recur_strategy(),
148            args::Heuristic::LookAhead => brch.lookahead_strategy(),
149            args::Heuristic::Random => brch.random_strategy(),
150        };
151
152        for rule_idx in brch.calls[call_idx].looks.iter().rev() {
153            self.stats.step();
154            self.ctx_cnt += 1;
155            if let Ok(new_brch) = self.apply_rule(brch, call_idx, *rule_idx) {
156                self.stack.push(new_brch);
157            }
158        }
159    }
160
161    fn apply_rule(
162        &mut self,
163        brch: &Branch,
164        call_idx: usize,
165        rule_idx: usize,
166    ) -> Result<Branch, ()> {
167        let rules = &self.prog.preds[&brch.calls[call_idx].pred].rules;
168        let rule_ctx = rules[rule_idx].tag_ctx(self.ctx_cnt);
169
170        let call = &brch.calls[call_idx];
171        assert_eq!(rule_ctx.head.len(), call.args.len());
172
173        let mut unifier: Unifier<IdentCtx, LitVal, OptCons<Ident>> = Unifier::new();
174        for (par, arg) in rule_ctx.head.iter().zip(call.args.iter()) {
175            if unifier.unify(par, arg).is_err() {
176                return Err(());
177            }
178        }
179
180        let mut new_brch = brch.clone();
181        new_brch.depth += 1;
182        new_brch.remove(call_idx);
183
184        for (prim, args) in rule_ctx.prims.iter() {
185            new_brch.prims.push((*prim, args.clone()));
186        }
187
188        if !super::progagate::propagate_unify(&mut new_brch.prims, &mut unifier) {
189            return Err(());
190        }
191
192        let mut new_history = call.history.clone();
193        new_history.push(
194            call.pred,
195            call.args.iter().map(|arg| arg.height()).collect(),
196        );
197
198        for (pred, polys, args) in rule_ctx.calls.iter().rev() {
199            let mut new_call = PredCall {
200                pred: *pred,
201                polys: polys.clone(),
202                args: args.clone(),
203                looks: (0..self.prog.preds[pred].rules.len()).collect(),
204                history: new_history.clone(),
205            };
206
207            if self.config.heuristic == args::Heuristic::LookAhead {
208                self.stats.step_la();
209                new_call.lookahead_update(&self.prog.preds[pred].rules);
210            }
211
212            new_brch.insert(call_idx, new_call);
213        }
214
215        for call in new_brch.calls.iter_mut() {
216            let mut dirty_flag = false;
217            for arg in call.args.iter_mut() {
218                if let Some(new_arg) = unifier.subst_opt(arg) {
219                    *arg = new_arg;
220                    dirty_flag = true;
221                }
222            }
223            // update lookahead information if any information is propagated
224            if dirty_flag && self.config.heuristic == args::Heuristic::LookAhead {
225                self.stats.step_la();
226                call.lookahead_update(&self.prog.preds[&call.pred].rules);
227            }
228        }
229
230        for (_par, val) in new_brch.answers.iter_mut() {
231            *val = unifier.subst(val);
232        }
233
234        Ok(new_brch)
235    }
236
237    pub fn run_iddfs_loop(&mut self, entry: Ident) -> usize {
238        for depth_limit in
239            (self.config.depth_step..=self.config.depth_limit).step_by(self.config.depth_step)
240        {
241            writeln!(
242                self.pipe_io.stat,
243                "[RUN]: try depth = {}... (found answer: {})",
244                depth_limit, self.ansr_cnt
245            )
246            .unwrap();
247
248            self.reset();
249
250            self.run_dfs_with_depth(entry, depth_limit - self.config.depth_step + 1, depth_limit);
251
252            let stat_res = self.stats.print_stat();
253            writeln!(self.pipe_io.stat, "{}", stat_res).unwrap();
254
255            if self.ansr_cnt >= self.config.answer_limit {
256                return self.ansr_cnt;
257            }
258        }
259        self.ansr_cnt
260    }
261}
262
/// End-to-end smoke test: parse, rename, type-check, compile and elaborate a
/// small list program, then run its query with Z3 and the interleaving
/// heuristic.
#[test]
fn test_runner() {
    let src = r#"
datatype IntList where
| Cons(Int, IntList)
| Nil
end

function append(xs: IntList, x: Int) -> IntList
begin
    match xs with
    | Cons(head, tail) => Cons(head, append(tail, x))
    | Nil => Cons(x, Nil)
    end
end

function is_elem(xs: IntList, x: Int) -> Bool
begin
    match xs with
    | Cons(head, tail) => if head == x then true else is_elem(tail, x) 
    | Nil => false
    end
end

function is_elem_after_append(xs: IntList, x: Int) -> Bool
begin
    guard is_elem(append(xs, x), x) = false;
    true
end

query is_elem_after_append(depth_step=5, depth_limit=50, answer_limit=100)
    "#;

    // front end: each pass must succeed without diagnostics
    let (mut ast, parse_errs) = crate::syntax::parser::parse_program(&src);
    assert!(parse_errs.is_empty());

    let rename_errs = crate::tych::rename::rename_pass(&mut ast);
    assert!(rename_errs.is_empty());

    let check_errs = crate::tych::check::check_pass(&ast);
    assert!(check_errs.is_empty());

    // back end: lower to the logic program the runner executes
    let mut program = crate::logic::compile::compile_pass(&ast);
    crate::logic::elaborate::elaborate_pass(&mut program);

    let mut io = PipeIO::empty();
    let mut runner = RunnerState::new(
        &program,
        &mut io,
        args::Solver::Z3,
        args::Heuristic::Interleave,
        false,
    );

    let first_query = &program.querys[0];
    first_query
        .params
        .iter()
        .for_each(|param| runner.config_set_param(param));
    runner.run_iddfs_loop(first_query.entry);
}