use super::config::{RunnerConfig, RunnerStats};
use super::solver;
use super::strategy::*;
use super::*;
use crate::cli::args::{self, CliArgs};
use crate::cli::pipeline::OutputWriter;
/// Mutable search state used to run a query against a compiled program.
///
/// Borrows the program (`'prog`) and the output sink (`'io`) and owns the
/// bookkeeping that the search methods of this type read and update.
pub struct RunnerState<'prog, 'io> {
/// Program whose predicates and rules are looked up during the search.
prog: &'prog Program,
/// Output sink; answers go to `output.answer`, progress/statistics to `output.stat`.
output: &'io mut OutputWriter,
/// Run settings (heuristic, depth step/limit, answer limit, debug mode).
config: RunnerConfig,
/// Step statistics; cleared by `reset` and printed after each depth window.
stats: RunnerStats,
/// Counter bumped before every rule application and handed to `tag_ctx`
/// on the applied rule; set back to 0 by `reset` / `init_stack`.
ctx_cnt: usize,
/// Number of answers written so far; `reset` does not clear it.
ansr_cnt: usize,
/// Random source used to shuffle candidate rules and by the random heuristic.
rng: rngs::ThreadRng,
/// Branches still waiting to be processed; the last pushed is popped first.
stack: Vec<Branch>,
/// Solver backend asked to `check_sat` the primitive constraints of a finished branch.
solver: Box<dyn solver::common::PrimSolver>,
}
impl<'prog, 'io> RunnerState<'prog, 'io> {
pub fn new(
prog: &'prog Program,
output: &'io mut OutputWriter,
args: &CliArgs,
) -> RunnerState<'prog, 'io> {
let solver_obj: Box<dyn solver::common::PrimSolver> = match args.solver {
args::Solver::Z3 => Box::new(super::solver::smtlib::SmtLibSolver::new(
super::solver::smtlib::SolverBackend::Z3,
)),
args::Solver::CVC5 => Box::new(super::solver::smtlib::SmtLibSolver::new(
super::solver::smtlib::SolverBackend::CVC5,
)),
args::Solver::NoSmt => Box::new(super::solver::no_smt::NoSmtSolver::new()),
};
let rng = rand::rng();
RunnerState {
prog,
output,
config: RunnerConfig::new(args),
stats: RunnerStats::new(),
ctx_cnt: 0,
ansr_cnt: 0,
rng,
stack: Vec::new(),
solver: solver_obj,
}
}
/// Forwards one query parameter to the runner configuration.
///
/// How `param` is interpreted is decided by `RunnerConfig::set_param`.
pub fn config_set_param(&mut self, param: &QueryParam) {
self.config.set_param(param);
}
/// Prepares the runner for a fresh depth window.
///
/// Clears the statistics and the context counter. The answer counter is
/// deliberately left alone so it keeps accumulating across windows.
///
/// # Panics
///
/// Panics if the branch stack is not empty.
fn reset(&mut self) {
self.stats.reset();
// A leftover branch would be explored again by the next window.
assert!(self.stack.is_empty());
self.ctx_cnt = 0;
}
/// Pushes the root branch for the entry predicate `pred` onto the stack.
///
/// The root branch has depth 0, no primitive constraints, and a single
/// pending call to `pred` whose arguments are the predicate's own parameters
/// tagged with context 0. The same tagged variables are recorded as the
/// initial answer terms. Also resets the context counter to 0.
///
/// # Panics
///
/// Panics if `pred` has a non-empty `polys` list.
fn init_stack(&mut self, pred: Ident) {
    assert!(self.prog.preds[&pred].polys.is_empty());
    self.ctx_cnt = 0;

    let entry = &self.prog.preds[&pred];
    let params: Vec<Ident> = entry.pars.iter().map(|(name, _)| *name).collect();

    // Every rule starts out as a candidate; `lookahead_update` then refreshes the list.
    let mut root_call = PredCall {
        pred,
        polys: Vec::new(),
        args: params
            .iter()
            .map(|name| Term::Var(name.tag_ctx(0)))
            .collect(),
        looks: (0..entry.rules.len()).collect(),
        depth: 0,
    };
    root_call.lookahead_update(&entry.rules);

    let answers = params
        .iter()
        .map(|name| (*name, Term::Var(name.tag_ctx(0))))
        .collect();

    self.stack.push(Branch {
        depth: 0,
        answers,
        prims: Vec::new(),
        calls: vec![root_call],
    });
}
/// Processes the branch stack depth-first within one depth window.
///
/// * A branch with no pending calls is handed to `solve_answer` when its
///   depth lies in `depth_start..=depth_end`; otherwise it is dropped.
/// * A branch with pending calls is expanded by `run_branch_step` only if
///   `depth + pending calls <= depth_end`; otherwise it is dropped.
///
/// The loop ends when the stack runs dry or once the answer limit has been
/// reached. In the latter case the remaining branches are discarded, so the
/// stack is empty whenever this method returns. That is the invariant
/// `reset` asserts before the next run.
///
/// In debug mode every popped branch is printed to stdout and the method
/// waits for a line on stdin before continuing.
fn run_dfs_with_depth(&mut self, depth_start: usize, depth_end: usize) {
    while let Some(mut brch) = self.stack.pop() {
        if self.config.debug_mode {
            println!("{brch}");
            let mut s = String::new();
            std::io::stdin().read_line(&mut s).unwrap();
        }
        if self.ansr_cnt >= self.config.answer_limit {
            // Drop the unexplored branches: leaving them behind would trip
            // the emptiness assertion in `reset` and leak stale work into
            // any later run on this runner.
            self.stack.clear();
            return;
        }
        if brch.calls.is_empty() {
            if (depth_start..=depth_end).contains(&brch.depth) {
                self.solve_answer(&brch);
            }
        } else if brch.depth + brch.calls.len() <= depth_end {
            self.run_branch_step(&mut brch);
        }
    }
}
/// Checks the primitive constraints of a finished branch and, if the solver
/// returns a model, writes the answer to `output.answer`.
///
/// The output is a header line with the branch depth and the time spent in
/// `check_sat`, followed by one `name = value` line per answer term with the
/// model substituted in. The answer counter is bumped once per reported
/// answer; nothing is written when the solver returns `None`.
fn solve_answer(&mut self, brch: &Branch) {
    let timer = std::time::Instant::now();
    let Some(model) = self.solver.check_sat(&brch.prims) else {
        return;
    };
    let elapsed = timer.elapsed();

    writeln!(
        self.output.answer,
        "[ANSWER]: depth = {}, solving time = {:?}",
        brch.depth, elapsed
    )
    .unwrap();

    // Wrap the model's literals as terms so they can be substituted into the answers.
    let bindings = model
        .into_iter()
        .map(|(var, lit)| (var, Term::Lit(lit)))
        .collect();
    for (par, val) in &brch.answers {
        let shown = val.substitute(&bindings);
        writeln!(self.output.answer, "{} = {}", par, shown).unwrap();
    }

    self.ansr_cnt += 1;
}
/// Expands one pending call of `brch` and pushes the resulting branches.
///
/// The call is picked by the configured heuristic. Its candidate rules are
/// shuffled, one search step is recorded, and each rule is then applied
/// (with follow-up reductions) in reverse of the shuffled order. Every
/// application that succeeds yields a successor branch that is pushed onto
/// the stack.
fn run_branch_step(&mut self, brch: &mut Branch) {
    use rand::seq::SliceRandom;

    let chosen = match self.config.heuristic {
        args::Heuristic::LeftBiased => brch.left_biased_strategy(),
        args::Heuristic::Interleave => brch.interleave_strategy(),
        args::Heuristic::SmallFirst => brch.small_first_strategy(),
        args::Heuristic::LookAhead => self.lookahead_choose(brch),
        args::Heuristic::Random => brch.random_strategy(&mut self.rng),
    };

    let mut rule_order = brch.calls[chosen].looks.clone();
    rule_order.shuffle(&mut self.rng);
    self.stats.step();

    for &rule_idx in rule_order.iter().rev() {
        let Some((successor, _steps)) = self.apply_rule_with_reduction(brch, chosen, rule_idx)
        else {
            continue;
        };
        self.stack.push(successor);
    }
}
/// Chooses which pending call of `brch` to expand by trial-applying rules.
///
/// Calls are visited in order of increasing candidate-rule count. For each
/// call, every candidate rule is applied with reductions. The step counts of
/// the applications that succeed and still leave pending calls are fed to
/// `tau_function`. A call whose tau is below 1.2 is returned immediately.
/// Otherwise the call with the lowest `tau + 0.001 * call depth` wins, and
/// index 0 is the fallback if no score beats `f32::MAX`.
///
/// The trial applications advance the context counter, and one lookahead
/// step is recorded per visited call.
///
/// # Panics
///
/// Panics if `brch` has no pending calls.
fn lookahead_choose(&mut self, brch: &Branch) -> usize {
    assert!(!brch.calls.is_empty());

    // Stable sort: calls with equally many candidates keep their relative order.
    let mut order: Vec<usize> = (0..brch.calls.len()).collect();
    order.sort_by_key(|&idx| brch.calls[idx].looks.len());

    // (score, call index) of the best candidate seen so far.
    let mut best: (f32, usize) = (f32::MAX, 0);

    for call_idx in order {
        self.stats.step_la();

        let mut step_counts = Vec::new();
        for &rule_idx in brch.calls[call_idx].looks.iter().rev() {
            match self.apply_rule_with_reduction(brch, call_idx, rule_idx) {
                Some((child, steps)) if !child.calls.is_empty() => step_counts.push(steps),
                _ => {}
            }
        }

        let tau = tau_function(&step_counts);
        if tau < 1.2 {
            return call_idx;
        }

        let score = tau + (brch.calls[call_idx].depth as f32) * (0.001_f32);
        if score < best.0 {
            best = (score, call_idx);
        }
    }

    best.1
}
/// Applies a rule and then keeps applying forced follow-up rules.
///
/// After the initial application, `check_reduction` is consulted repeatedly.
/// While it names a call that has exactly one candidate rule, that rule is
/// applied as well. The chain stops when no call is named, when the named
/// call has no candidate left, or once `MAX_REDUCTION` is reached.
///
/// Returns the resulting branch together with a step count: 1 for the
/// initial application plus one per follow-up application performed.
/// Returns `None` as soon as any application fails.
///
/// # Panics
///
/// Panics if `check_reduction` names a call with more than one candidate rule.
fn apply_rule_with_reduction(
    &mut self,
    brch: &Branch,
    call_idx: usize,
    rule_idx: usize,
) -> Option<(Branch, usize)> {
    const MAX_REDUCTION: usize = 10;

    let mut current = self.apply_rule(brch, call_idx, rule_idx)?;
    let mut steps = 1;

    while steps < MAX_REDUCTION {
        let Some(forced_idx) = current.check_reduction() else {
            break;
        };
        let looks = &current.calls[forced_idx].looks;
        assert!(looks.len() <= 1);
        // No candidate left for the named call: hand the branch back as it is.
        let Some(&only_rule) = looks.first() else {
            break;
        };
        current = self.apply_rule(&current, forced_idx, only_rule)?;
        steps += 1;
    }

    Some((current, steps))
}
/// Applies rule `rule_idx` of the predicate called at `call_idx` to `brch`.
///
/// Steps, in order:
/// 1. bump the context counter and tag the rule with it;
/// 2. unify the rule head with the call's arguments;
/// 3. clone the branch, increase its depth, remove the call and append the
///    rule's primitive constraints;
/// 4. run `propagate_unify` over the primitives;
/// 5. insert the rule's body calls at the position of the removed call;
/// 6. substitute the unifier into every pending call (refreshing the
///    candidate list of each call that changed) and into the answer terms.
///
/// Returns `None` when head unification fails or `propagate_unify` returns
/// `false`. The context counter is bumped even in that case.
///
/// # Panics
///
/// Panics if the rule head and the call have different numbers of arguments.
fn apply_rule(&mut self, brch: &Branch, call_idx: usize, rule_idx: usize) -> Option<Branch> {
    let preds = &self.prog.preds;
    let rules = &preds[&brch.calls[call_idx].pred].rules;

    self.ctx_cnt += 1;
    let rule_ctx = rules[rule_idx].tag_ctx(self.ctx_cnt);

    let target = &brch.calls[call_idx];
    assert_eq!(rule_ctx.head.len(), target.args.len());

    let mut unifier: Unifier<IdentCtx, LitVal, OptCons<Ident>> = Unifier::new();
    // `any` stops at the first failing pair, just like an early return would.
    let head_clash = rule_ctx
        .head
        .iter()
        .zip(target.args.iter())
        .any(|(par, arg)| unifier.unify(par, arg).is_err());
    if head_clash {
        return None;
    }

    let mut successor = brch.clone();
    successor.depth += 1;
    successor.remove(call_idx);
    successor.prims.extend(
        rule_ctx
            .prims
            .iter()
            .map(|(prim, args)| (*prim, args.clone())),
    );

    if !super::progagate::propagate_unify(&mut successor.prims, &mut unifier) {
        return None;
    }

    // Inserting in reverse at the same index leaves the body calls in rule order.
    for (pred, polys, args) in rule_ctx.calls.iter().rev() {
        let callee_rules = &preds[pred].rules;
        let mut body_call = PredCall {
            pred: *pred,
            polys: polys.clone(),
            args: args.clone(),
            looks: (0..callee_rules.len()).collect(),
            depth: target.depth + 1,
        };
        body_call.lookahead_update(callee_rules);
        successor.insert(call_idx, body_call);
    }

    for pending in &mut successor.calls {
        let mut changed = false;
        for arg in &mut pending.args {
            if let Some(rewritten) = unifier.subst_opt(arg) {
                *arg = rewritten;
                changed = true;
            }
        }
        // Only calls whose arguments actually changed get their candidates refreshed.
        if changed {
            pending.lookahead_update(&preds[&pending.pred].rules);
        }
    }

    for (_par, val) in &mut successor.answers {
        *val = unifier.subst(val);
    }

    Some(successor)
}
pub fn run_iddfs_loop(&mut self, entry: Ident) -> usize {
for depth_limit in
(self.config.depth_step..=self.config.depth_limit).step_by(self.config.depth_step)
{
writeln!(
self.output.stat,
"[RUN]: try depth = {}... (found answer: {})",
depth_limit, self.ansr_cnt
)
.unwrap();
self.reset();
self.init_stack(entry);
self.run_dfs_with_depth(depth_limit - self.config.depth_step + 1, depth_limit);
let stat_res = self.stats.print_stat();
writeln!(self.output.stat, "{stat_res}").unwrap();
if self.ansr_cnt >= self.config.answer_limit {
return self.ansr_cnt;
}
}
self.ansr_cnt
}
}
/// Smoke test: pushes a small list program through the whole front end and
/// runs its query with the test CLI arguments. It only checks that every
/// pass reports no errors and that the search finishes without panicking.
#[test]
fn test_runner() {
    const SRC: &str = r#"
datatype IntList where
| Cons(Int, IntList)
| Nil
end
function append(xs: IntList, x: Int) -> IntList
begin
match xs with
| Cons(head, tail) => Cons(head, append(tail, x))
| Nil => Cons(x, Nil)
end
end
function is_elem(xs: IntList, x: Int) -> Bool
begin
match xs with
| Cons(head, tail) => if head == x then true else is_elem(tail, x)
| Nil => false
end
end
function is_elem_after_append(xs: IntList, x: Int) -> Bool
begin
guard is_elem(append(xs, x), x) = false;
true
end
query is_elem_after_append(depth_step=5, depth_limit=50, answer_limit=100)
"#;

    // Front end: parse, rename, type-check.
    let (mut ast, errs) = crate::syntax::parser::parse_program(SRC);
    assert!(errs.is_empty());
    let errs = crate::tych::rename::rename_pass(&mut ast);
    assert!(errs.is_empty());
    let errs = crate::tych::check::check_pass(&ast);
    assert!(errs.is_empty());

    // Lower to the logic program the runner works on.
    let mut logic = crate::logic::compile::compile_pass(&ast);
    crate::logic::elaborate::elaborate_pass(&mut logic);

    let mut sink = OutputWriter::empty();
    let cli_args = args::get_test_cli_args(std::path::PathBuf::new());
    let mut runner = RunnerState::new(&logic, &mut sink, &cli_args);

    // Apply the query's own parameters before starting the search.
    let query = &logic.querys[0];
    query
        .params
        .iter()
        .for_each(|param| runner.config_set_param(param));
    runner.run_iddfs_loop(query.entry);
}