1use super::config::{RunnerConfig, RunnerStats};
2use super::solver;
3use super::strategy::*;
4use super::*;
5use crate::cli::args::{self, CliArgs};
6use crate::cli::pipeline::PipeIO;
7
8pub struct RunnerState<'prog, 'io> {
9 prog: &'prog Program,
10 pipe_io: &'io mut PipeIO,
11 config: RunnerConfig,
12 stats: RunnerStats,
13 ctx_cnt: usize,
14 ansr_cnt: usize,
15 rng: rngs::ThreadRng,
16 stack: Vec<Branch>,
17 solver: Box<dyn solver::common::PrimSolver>,
18}
19
20impl<'prog, 'io> RunnerState<'prog, 'io> {
21 pub fn new(
22 prog: &'prog Program,
23 pipe: &'io mut PipeIO,
24 args: &CliArgs,
25 ) -> RunnerState<'prog, 'io> {
26 let solver_obj: Box<dyn solver::common::PrimSolver> = match args.solver {
27 args::Solver::Z3 => Box::new(super::solver::smtlib::SmtLibSolver::new(
28 super::solver::smtlib::SolverBackend::Z3,
29 )),
30 args::Solver::CVC5 => Box::new(super::solver::smtlib::SmtLibSolver::new(
31 super::solver::smtlib::SolverBackend::CVC5,
32 )),
33 args::Solver::NoSmt => Box::new(super::solver::no_smt::NoSmtSolver::new()),
34 };
35
36 let rng = rand::rng();
37
38 RunnerState {
39 prog,
40 pipe_io: pipe,
41 config: RunnerConfig::new(args),
42 stats: RunnerStats::new(),
43 ctx_cnt: 0,
44 ansr_cnt: 0,
45 rng,
46 stack: Vec::new(),
47 solver: solver_obj,
48 }
49 }
50
51 pub fn config_set_param(&mut self, param: &QueryParam) {
52 self.config.set_param(param);
53 }
54
55 fn reset(&mut self) {
56 self.stats.reset();
57 assert!(self.stack.is_empty());
58 self.ctx_cnt = 0;
59 }
60
61 fn init_stack(&mut self, pred: Ident) {
62 assert!(self.prog.preds[&pred].polys.is_empty());
64
65 self.ctx_cnt = 0;
66 let pars: Vec<Ident> = self.prog.preds[&pred]
67 .pars
68 .iter()
69 .map(|(par, _typ)| *par)
70 .collect();
71
72 let rules = &self.prog.preds[&pred].rules;
73 let mut call = PredCall {
74 pred,
75 polys: Vec::new(),
76 args: pars.iter().map(|par| Term::Var(par.tag_ctx(0))).collect(),
77 looks: (0..rules.len()).collect(),
78 history: History::new(),
79 };
80
81 if self.config.heuristic == args::Heuristic::LookAhead {
82 self.stats.step_la();
83 call.lookahead_update(rules);
84 }
85
86 let brch = Branch {
87 depth: 0,
88 answers: pars
89 .iter()
90 .map(|par| (*par, Term::Var(par.tag_ctx(0))))
91 .collect(),
92 prims: Vec::new(),
93 calls: vec![call],
94 };
95
96 self.stack.push(brch);
97 }
98
99 fn run_dfs_with_depth(&mut self, depth_start: usize, depth_end: usize) {
100 while let Some(mut brch) = self.stack.pop() {
101 if self.config.debug_mode {
102 println!("{brch}");
103
104 let mut s = String::new();
106 std::io::stdin().read_line(&mut s).unwrap();
107 }
108
109 if self.ansr_cnt >= self.config.answer_limit {
110 return;
111 }
112 assert!(brch.depth <= depth_end);
113 if brch.calls.is_empty() {
114 if brch.depth >= depth_start {
115 self.solve_answer(&brch);
116 }
117 } else if brch.depth + brch.calls.len() <= depth_end {
118 self.run_branch_step(&mut brch);
119 }
120 }
121 }
122
123 fn solve_answer(&mut self, brch: &Branch) {
124 let start = std::time::Instant::now();
125
126 if let Some(map) = self.solver.check_sat(&brch.prims) {
127 let duration = start.elapsed();
128 writeln!(
129 self.pipe_io.output,
130 "[ANSWER]: depth = {}, solving time = {:?}",
131 brch.depth, duration
132 )
133 .unwrap();
134
135 let map = map
136 .into_iter()
137 .map(|(var, lit)| (var, Term::Lit(lit)))
138 .collect();
139
140 for (par, val) in &brch.answers {
141 writeln!(self.pipe_io.output, "{} = {}", par, val.substitute(&map)).unwrap();
142 }
143 self.ansr_cnt += 1;
144 }
145 }
146
147 fn run_branch_step(&mut self, brch: &mut Branch) {
148 let call_idx = match self.config.heuristic {
149 args::Heuristic::LeftBiased => brch.left_biased_strategy(),
150 args::Heuristic::Interleave => brch.naive_strategy(1),
151 args::Heuristic::StructRecur => brch.struct_recur_strategy(),
152 args::Heuristic::LookAhead => brch.lookahead_strategy(),
153 args::Heuristic::Random => brch.random_strategy(&mut self.rng),
154 };
155
156 use rand::seq::SliceRandom;
157 let mut looks = brch.calls[call_idx].looks.clone();
158 looks.shuffle(&mut self.rng);
159
160 for rule_idx in looks.iter().rev() {
161 self.stats.step();
162 self.ctx_cnt += 1;
163 if let Ok(new_brch) = self.apply_rule(brch, call_idx, *rule_idx) {
164 self.stack.push(new_brch);
165 }
166 }
167 }
168
169 fn apply_rule(
170 &mut self,
171 brch: &Branch,
172 call_idx: usize,
173 rule_idx: usize,
174 ) -> Result<Branch, ()> {
175 let rules = &self.prog.preds[&brch.calls[call_idx].pred].rules;
176 let rule_ctx = rules[rule_idx].tag_ctx(self.ctx_cnt);
177
178 let call = &brch.calls[call_idx];
179 assert_eq!(rule_ctx.head.len(), call.args.len());
180
181 let mut unifier: Unifier<IdentCtx, LitVal, OptCons<Ident>> = Unifier::new();
182 for (par, arg) in rule_ctx.head.iter().zip(call.args.iter()) {
183 if unifier.unify(par, arg).is_err() {
184 return Err(());
185 }
186 }
187
188 let mut new_brch = brch.clone();
189 new_brch.depth += 1;
190 new_brch.remove(call_idx);
191
192 for (prim, args) in &rule_ctx.prims {
193 new_brch.prims.push((*prim, args.clone()));
194 }
195
196 if !super::progagate::propagate_unify(&mut new_brch.prims, &mut unifier) {
197 return Err(());
198 }
199
200 let mut new_history = call.history.clone();
201 new_history.push(
202 call.pred,
203 call.args.iter().map(|arg| arg.height()).collect(),
204 );
205
206 for (pred, polys, args) in rule_ctx.calls.iter().rev() {
207 let mut new_call = PredCall {
208 pred: *pred,
209 polys: polys.clone(),
210 args: args.clone(),
211 looks: (0..self.prog.preds[pred].rules.len()).collect(),
212 history: new_history.clone(),
213 };
214
215 if self.config.heuristic == args::Heuristic::LookAhead {
216 self.stats.step_la();
217 new_call.lookahead_update(&self.prog.preds[pred].rules);
218 }
219
220 new_brch.insert(call_idx, new_call);
221 }
222
223 for call in &mut new_brch.calls {
224 let mut dirty_flag = false;
225 for arg in &mut call.args {
226 if let Some(new_arg) = unifier.subst_opt(arg) {
227 *arg = new_arg;
228 dirty_flag = true;
229 }
230 }
231 if dirty_flag && self.config.heuristic == args::Heuristic::LookAhead {
233 self.stats.step_la();
234 call.lookahead_update(&self.prog.preds[&call.pred].rules);
235 }
236 }
237
238 for (_par, val) in &mut new_brch.answers {
239 *val = unifier.subst(val);
240 }
241
242 Ok(new_brch)
243 }
244
245 pub fn run_iddfs_loop(&mut self, entry: Ident) -> usize {
246 for depth_limit in
247 (self.config.depth_step..=self.config.depth_limit).step_by(self.config.depth_step)
248 {
249 writeln!(
250 self.pipe_io.stat,
251 "[RUN]: try depth = {}... (found answer: {})",
252 depth_limit, self.ansr_cnt
253 )
254 .unwrap();
255
256 self.reset();
257 self.init_stack(entry);
258 self.run_dfs_with_depth(depth_limit - self.config.depth_step + 1, depth_limit);
259
260 let stat_res = self.stats.print_stat();
261 writeln!(self.pipe_io.stat, "{stat_res}").unwrap();
262
263 if self.ansr_cnt >= self.config.answer_limit {
264 return self.ansr_cnt;
265 }
266 }
267 self.ansr_cnt
268 }
269}
270
/// End-to-end smoke test: parse, rename, type-check, compile and elaborate a small
/// list program, then run its query through the iterative-deepening runner.
#[test]
fn test_runner() {
    let src: &'static str = r#"
datatype IntList where
| Cons(Int, IntList)
| Nil
end

function append(xs: IntList, x: Int) -> IntList
begin
    match xs with
    | Cons(head, tail) => Cons(head, append(tail, x))
    | Nil => Cons(x, Nil)
    end
end

function is_elem(xs: IntList, x: Int) -> Bool
begin
    match xs with
    | Cons(head, tail) => if head == x then true else is_elem(tail, x)
    | Nil => false
    end
end

function is_elem_after_append(xs: IntList, x: Int) -> Bool
begin
    guard is_elem(append(xs, x), x) = false;
    true
end

query is_elem_after_append(depth_step=5, depth_limit=50, answer_limit=100)
    "#;

    // Front end: every pass must succeed without diagnostics.
    let (mut ast, parse_errs) = crate::syntax::parser::parse_program(src);
    assert!(parse_errs.is_empty());

    let rename_errs = crate::tych::rename::rename_pass(&mut ast);
    assert!(rename_errs.is_empty());

    let check_errs = crate::tych::check::check_pass(&ast);
    assert!(check_errs.is_empty());

    // Lower to the logic program the runner operates on.
    let mut prog = crate::logic::compile::compile_pass(&ast);
    crate::logic::elaborate::elaborate_pass(&mut prog);

    let mut pipe_io = PipeIO::empty();
    let cli_args = args::get_test_cli_args(std::path::PathBuf::new());
    let mut runner = RunnerState::new(&prog, &mut pipe_io, &cli_args);

    let query = &prog.querys[0];
    query
        .params
        .iter()
        .for_each(|param| runner.config_set_param(param));
    runner.run_iddfs_loop(query.entry);
}
330}