Skip to main content

prune_lang/interp/
runner.rs

1use super::config::{RunnerConfig, RunnerStats};
2use super::solver;
3use super::strategy::*;
4use super::*;
5use crate::cli::args::{self, CliArgs};
6use crate::cli::pipeline::PipeIO;
7
8pub struct RunnerState<'prog, 'io> {
9    prog: &'prog Program,
10    pipe_io: &'io mut PipeIO,
11    config: RunnerConfig,
12    stats: RunnerStats,
13    ctx_cnt: usize,
14    ansr_cnt: usize,
15    rng: rngs::ThreadRng,
16    stack: Vec<Branch>,
17    solver: Box<dyn solver::common::PrimSolver>,
18}
19
20impl<'prog, 'io> RunnerState<'prog, 'io> {
21    pub fn new(
22        prog: &'prog Program,
23        pipe: &'io mut PipeIO,
24        args: &CliArgs,
25    ) -> RunnerState<'prog, 'io> {
26        let solver_obj: Box<dyn solver::common::PrimSolver> = match args.solver {
27            args::Solver::Z3 => Box::new(super::solver::smtlib::SmtLibSolver::new(
28                super::solver::smtlib::SolverBackend::Z3,
29            )),
30            args::Solver::CVC5 => Box::new(super::solver::smtlib::SmtLibSolver::new(
31                super::solver::smtlib::SolverBackend::CVC5,
32            )),
33            args::Solver::NoSmt => Box::new(super::solver::no_smt::NoSmtSolver::new()),
34        };
35
36        let rng = rand::rng();
37
38        RunnerState {
39            prog,
40            pipe_io: pipe,
41            config: RunnerConfig::new(args),
42            stats: RunnerStats::new(),
43            ctx_cnt: 0,
44            ansr_cnt: 0,
45            rng,
46            stack: Vec::new(),
47            solver: solver_obj,
48        }
49    }
50
51    pub fn config_set_param(&mut self, param: &QueryParam) {
52        self.config.set_param(param);
53    }
54
55    fn reset(&mut self) {
56        self.stats.reset();
57        assert!(self.stack.is_empty());
58        self.ctx_cnt = 0;
59    }
60
61    fn init_stack(&mut self, pred: Ident) {
62        // predicate for query can not be polymorphic!
63        assert!(self.prog.preds[&pred].polys.is_empty());
64
65        self.ctx_cnt = 0;
66        let pars: Vec<Ident> = self.prog.preds[&pred]
67            .pars
68            .iter()
69            .map(|(par, _typ)| *par)
70            .collect();
71
72        let rules = &self.prog.preds[&pred].rules;
73        let mut call = PredCall {
74            pred,
75            polys: Vec::new(),
76            args: pars.iter().map(|par| Term::Var(par.tag_ctx(0))).collect(),
77            looks: (0..rules.len()).collect(),
78            history: History::new(),
79        };
80
81        if self.config.heuristic == args::Heuristic::LookAhead {
82            self.stats.step_la();
83            call.lookahead_update(rules);
84        }
85
86        let brch = Branch {
87            depth: 0,
88            answers: pars
89                .iter()
90                .map(|par| (*par, Term::Var(par.tag_ctx(0))))
91                .collect(),
92            prims: Vec::new(),
93            calls: vec![call],
94        };
95
96        self.stack.push(brch);
97    }
98
99    fn run_dfs_with_depth(&mut self, depth_start: usize, depth_end: usize) {
100        while let Some(mut brch) = self.stack.pop() {
101            if self.config.debug_mode {
102                println!("{brch}");
103
104                // pause to wait for any input
105                let mut s = String::new();
106                std::io::stdin().read_line(&mut s).unwrap();
107            }
108
109            if self.ansr_cnt >= self.config.answer_limit {
110                return;
111            }
112            assert!(brch.depth <= depth_end);
113            if brch.calls.is_empty() {
114                if brch.depth >= depth_start {
115                    self.solve_answer(&brch);
116                }
117            } else if brch.depth + brch.calls.len() <= depth_end {
118                self.run_branch_step(&mut brch);
119            }
120        }
121    }
122
123    fn solve_answer(&mut self, brch: &Branch) {
124        let start = std::time::Instant::now();
125
126        if let Some(map) = self.solver.check_sat(&brch.prims) {
127            let duration = start.elapsed();
128            writeln!(
129                self.pipe_io.output,
130                "[ANSWER]: depth = {}, solving time = {:?}",
131                brch.depth, duration
132            )
133            .unwrap();
134
135            let map = map
136                .into_iter()
137                .map(|(var, lit)| (var, Term::Lit(lit)))
138                .collect();
139
140            for (par, val) in &brch.answers {
141                writeln!(self.pipe_io.output, "{} = {}", par, val.substitute(&map)).unwrap();
142            }
143            self.ansr_cnt += 1;
144        }
145    }
146
147    fn run_branch_step(&mut self, brch: &mut Branch) {
148        let call_idx = match self.config.heuristic {
149            args::Heuristic::LeftBiased => brch.left_biased_strategy(),
150            args::Heuristic::Interleave => brch.naive_strategy(1),
151            args::Heuristic::StructRecur => brch.struct_recur_strategy(),
152            args::Heuristic::LookAhead => brch.lookahead_strategy(),
153            args::Heuristic::Random => brch.random_strategy(&mut self.rng),
154        };
155
156        use rand::seq::SliceRandom;
157        let mut looks = brch.calls[call_idx].looks.clone();
158        looks.shuffle(&mut self.rng);
159
160        for rule_idx in looks.iter().rev() {
161            self.stats.step();
162            self.ctx_cnt += 1;
163            if let Ok(new_brch) = self.apply_rule(brch, call_idx, *rule_idx) {
164                self.stack.push(new_brch);
165            }
166        }
167    }
168
169    fn apply_rule(
170        &mut self,
171        brch: &Branch,
172        call_idx: usize,
173        rule_idx: usize,
174    ) -> Result<Branch, ()> {
175        let rules = &self.prog.preds[&brch.calls[call_idx].pred].rules;
176        let rule_ctx = rules[rule_idx].tag_ctx(self.ctx_cnt);
177
178        let call = &brch.calls[call_idx];
179        assert_eq!(rule_ctx.head.len(), call.args.len());
180
181        let mut unifier: Unifier<IdentCtx, LitVal, OptCons<Ident>> = Unifier::new();
182        for (par, arg) in rule_ctx.head.iter().zip(call.args.iter()) {
183            if unifier.unify(par, arg).is_err() {
184                return Err(());
185            }
186        }
187
188        let mut new_brch = brch.clone();
189        new_brch.depth += 1;
190        new_brch.remove(call_idx);
191
192        for (prim, args) in &rule_ctx.prims {
193            new_brch.prims.push((*prim, args.clone()));
194        }
195
196        if !super::progagate::propagate_unify(&mut new_brch.prims, &mut unifier) {
197            return Err(());
198        }
199
200        let mut new_history = call.history.clone();
201        new_history.push(
202            call.pred,
203            call.args.iter().map(|arg| arg.height()).collect(),
204        );
205
206        for (pred, polys, args) in rule_ctx.calls.iter().rev() {
207            let mut new_call = PredCall {
208                pred: *pred,
209                polys: polys.clone(),
210                args: args.clone(),
211                looks: (0..self.prog.preds[pred].rules.len()).collect(),
212                history: new_history.clone(),
213            };
214
215            if self.config.heuristic == args::Heuristic::LookAhead {
216                self.stats.step_la();
217                new_call.lookahead_update(&self.prog.preds[pred].rules);
218            }
219
220            new_brch.insert(call_idx, new_call);
221        }
222
223        for call in &mut new_brch.calls {
224            let mut dirty_flag = false;
225            for arg in &mut call.args {
226                if let Some(new_arg) = unifier.subst_opt(arg) {
227                    *arg = new_arg;
228                    dirty_flag = true;
229                }
230            }
231            // update lookahead information if any information is propagated
232            if dirty_flag && self.config.heuristic == args::Heuristic::LookAhead {
233                self.stats.step_la();
234                call.lookahead_update(&self.prog.preds[&call.pred].rules);
235            }
236        }
237
238        for (_par, val) in &mut new_brch.answers {
239            *val = unifier.subst(val);
240        }
241
242        Ok(new_brch)
243    }
244
245    pub fn run_iddfs_loop(&mut self, entry: Ident) -> usize {
246        for depth_limit in
247            (self.config.depth_step..=self.config.depth_limit).step_by(self.config.depth_step)
248        {
249            writeln!(
250                self.pipe_io.stat,
251                "[RUN]: try depth = {}... (found answer: {})",
252                depth_limit, self.ansr_cnt
253            )
254            .unwrap();
255
256            self.reset();
257            self.init_stack(entry);
258            self.run_dfs_with_depth(depth_limit - self.config.depth_step + 1, depth_limit);
259
260            let stat_res = self.stats.print_stat();
261            writeln!(self.pipe_io.stat, "{stat_res}").unwrap();
262
263            if self.ansr_cnt >= self.config.answer_limit {
264                return self.ansr_cnt;
265            }
266        }
267        self.ansr_cnt
268    }
269}
270
// End-to-end smoke test: parse, rename, type-check, compile and elaborate a
// small list program, then run its query through the runner. The test only
// checks that every front-end pass reports no errors and that the search
// runs to completion without panicking.
#[test]
fn test_runner() {
    let src: &'static str = r#"
datatype IntList where
| Cons(Int, IntList)
| Nil
end

function append(xs: IntList, x: Int) -> IntList
begin
    match xs with
    | Cons(head, tail) => Cons(head, append(tail, x))
    | Nil => Cons(x, Nil)
    end
end

function is_elem(xs: IntList, x: Int) -> Bool
begin
    match xs with
    | Cons(head, tail) => if head == x then true else is_elem(tail, x) 
    | Nil => false
    end
end

function is_elem_after_append(xs: IntList, x: Int) -> Bool
begin
    guard is_elem(append(xs, x), x) = false;
    true
end

query is_elem_after_append(depth_step=5, depth_limit=50, answer_limit=100)
    "#;

    // Front end: each pass must finish without reporting errors.
    let (mut ast, errs) = crate::syntax::parser::parse_program(src);
    assert!(errs.is_empty());

    let errs = crate::tych::rename::rename_pass(&mut ast);
    assert!(errs.is_empty());

    let errs = crate::tych::check::check_pass(&ast);
    assert!(errs.is_empty());

    // Back end: lower to the logic program the runner works on.
    let mut logic_prog = crate::logic::compile::compile_pass(&ast);
    crate::logic::elaborate::elaborate_pass(&mut logic_prog);

    let mut pipe_io = PipeIO::empty();
    let cli_args = args::get_test_cli_args(std::path::PathBuf::new());
    let mut runner = RunnerState::new(&logic_prog, &mut pipe_io, &cli_args);

    // Apply the query's own parameters on top of the CLI defaults, then search.
    let query = &logic_prog.querys[0];
    query
        .params
        .iter()
        .for_each(|param| runner.config_set_param(param));
    runner.run_iddfs_loop(query.entry);
}