prune-lang 0.2.3

Prune is a constraint logic programming language with branching heuristics.
Documentation
use super::config::{RunnerConfig, RunnerStats};
use super::solver;
use super::strategy::*;
use super::*;
use crate::cli::args::{self, CliArgs};
use crate::cli::pipeline::OutputWriter;

/// Search state for executing queries of a compiled [`Program`].
///
/// The runner performs an iterative-deepening depth-first search over
/// branches, delegating primitive constraints to a pluggable solver and
/// reporting answers and statistics through an [`OutputWriter`].
pub struct RunnerState<'prog, 'io> {
    /// The program whose predicates and rules are being searched.
    prog: &'prog Program,
    /// Destination for answers (`answer`) and per-iteration statistics (`stat`).
    output: &'io mut OutputWriter,
    /// Search configuration (limits, heuristic, debug mode, ...).
    config: RunnerConfig,
    /// Counters printed after each depth iteration.
    stats: RunnerStats,
    /// Counter used to tag each rule application with a fresh context id.
    ctx_cnt: usize,
    /// Number of answers reported so far.
    ansr_cnt: usize,
    /// Randomness source for rule-order shuffling and the random heuristic.
    rng: rngs::ThreadRng,
    /// Pending branches of the depth-first search (top = next to explore).
    stack: Vec<Branch>,
    /// Backend that checks satisfiability of a branch's primitive constraints.
    solver: Box<dyn solver::common::PrimSolver>,
}

impl<'prog, 'io> RunnerState<'prog, 'io> {
    /// Builds a runner for `prog` that reports to `output`.
    ///
    /// The primitive-constraint solver is selected from `args.solver`, and the
    /// initial configuration comes from `RunnerConfig::new(args)`. Per-query
    /// parameters can be applied afterwards with `config_set_param`.
    pub fn new(
        prog: &'prog Program,
        output: &'io mut OutputWriter,
        args: &CliArgs,
    ) -> RunnerState<'prog, 'io> {
        let solver_obj: Box<dyn solver::common::PrimSolver> = match args.solver {
            args::Solver::Z3 => Box::new(super::solver::smtlib::SmtLibSolver::new(
                super::solver::smtlib::SolverBackend::Z3,
            )),
            args::Solver::CVC5 => Box::new(super::solver::smtlib::SmtLibSolver::new(
                super::solver::smtlib::SolverBackend::CVC5,
            )),
            args::Solver::NoSmt => Box::new(super::solver::no_smt::NoSmtSolver::new()),
        };

        let rng = rand::rng();

        RunnerState {
            prog,
            output,
            config: RunnerConfig::new(args),
            stats: RunnerStats::new(),
            ctx_cnt: 0,
            ansr_cnt: 0,
            rng,
            stack: Vec::new(),
            solver: solver_obj,
        }
    }

    /// Applies a single query parameter (e.g. a depth or answer limit) to the
    /// runner's configuration.
    pub fn config_set_param(&mut self, param: &QueryParam) {
        self.config.set_param(param);
    }

    /// Prepares for a new depth iteration: resets statistics and the context
    /// counter. `ansr_cnt` is not touched, so answers accumulate across
    /// iterations of `run_iddfs_loop`.
    ///
    /// # Panics
    /// Panics if the branch stack still holds branches.
    fn reset(&mut self) {
        self.stats.reset();
        assert!(self.stack.is_empty());
        self.ctx_cnt = 0;
    }

    /// Seeds the stack with one branch that calls `pred` on fresh variables
    /// (context 0), one per parameter. The same variables are recorded as the
    /// branch's answer terms.
    ///
    /// # Panics
    /// Panics if `pred` is polymorphic or not present in the program.
    fn init_stack(&mut self, pred: Ident) {
        // predicate for query can not be polymorphic!
        assert!(self.prog.preds[&pred].polys.is_empty());

        self.ctx_cnt = 0;
        let pars: Vec<Ident> = self.prog.preds[&pred]
            .pars
            .iter()
            .map(|(par, _typ)| *par)
            .collect();

        let rules = &self.prog.preds[&pred].rules;
        let mut call = PredCall {
            pred,
            polys: Vec::new(),
            args: pars.iter().map(|par| Term::Var(par.tag_ctx(0))).collect(),
            looks: (0..rules.len()).collect(),
            depth: 0,
        };
        call.lookahead_update(rules);

        let brch = Branch {
            depth: 0,
            answers: pars
                .iter()
                .map(|par| (*par, Term::Var(par.tag_ctx(0))))
                .collect(),
            prims: Vec::new(),
            calls: vec![call],
        };

        self.stack.push(brch);
    }

    /// Drains the branch stack depth-first. An answer is reported only when a
    /// finished branch (no pending calls) has a depth in
    /// `depth_start..=depth_end`.
    ///
    /// A branch is pruned when its depth plus its number of pending calls
    /// exceeds `depth_end`, because each pending call costs at least one more
    /// rule application (each `apply_rule` adds one to the depth).
    ///
    /// Stops as soon as `answer_limit` answers have been found. Any unexplored
    /// branches are discarded at that point, so a later `reset` (which requires
    /// an empty stack) does not panic.
    fn run_dfs_with_depth(&mut self, depth_start: usize, depth_end: usize) {
        while let Some(mut brch) = self.stack.pop() {
            if self.config.debug_mode {
                println!("{brch}");

                // pause to wait for any input
                let mut s = String::new();
                std::io::stdin().read_line(&mut s).unwrap();
            }

            if self.ansr_cnt >= self.config.answer_limit {
                // Leave the runner reusable: `reset` asserts an empty stack.
                self.stack.clear();
                return;
            }
            if brch.calls.is_empty() {
                if brch.depth >= depth_start && brch.depth <= depth_end {
                    self.solve_answer(&brch);
                }
            } else if brch.depth + brch.calls.len() <= depth_end {
                self.run_branch_step(&mut brch);
            }
        }
    }

    /// Checks the primitive constraints of a finished branch with the solver.
    ///
    /// If they are satisfiable, this writes a header line with the depth and
    /// solving time to `output.answer`. It then writes one line per query
    /// parameter, with the solver's model substituted, and increments `ansr_cnt`.
    fn solve_answer(&mut self, brch: &Branch) {
        let start = std::time::Instant::now();

        if let Some(map) = self.solver.check_sat(&brch.prims) {
            let duration = start.elapsed();
            writeln!(
                self.output.answer,
                "[ANSWER]: depth = {}, solving time = {:?}",
                brch.depth, duration
            )
            .unwrap();

            let map = map
                .into_iter()
                .map(|(var, lit)| (var, Term::Lit(lit)))
                .collect();

            for (par, val) in &brch.answers {
                writeln!(self.output.answer, "{} = {}", par, val.substitute(&map)).unwrap();
            }
            self.ansr_cnt += 1;
        }
    }

    /// Expands one branch.
    ///
    /// The call to resolve is chosen by the configured heuristic. Each of that
    /// call's candidate rules is applied in a randomly shuffled order, followed
    /// by deterministic reductions, and every branch that survives is pushed
    /// onto the stack.
    fn run_branch_step(&mut self, brch: &mut Branch) {
        use rand::seq::SliceRandom;

        let call_idx = match self.config.heuristic {
            args::Heuristic::LeftBiased => brch.left_biased_strategy(),
            args::Heuristic::Interleave => brch.interleave_strategy(),
            args::Heuristic::SmallFirst => brch.small_first_strategy(),
            args::Heuristic::LookAhead => self.lookahead_choose(brch),
            args::Heuristic::Random => brch.random_strategy(&mut self.rng),
        };

        let mut looks = brch.calls[call_idx].looks.clone();
        looks.shuffle(&mut self.rng);

        self.stats.step();
        for &rule_idx in looks.iter().rev() {
            if let Some((child, _steps)) =
                self.apply_rule_with_reduction(brch, call_idx, rule_idx)
            {
                self.stack.push(child);
            }
        }
    }

    /// Picks which pending call of `brch` to resolve next by trial expansion.
    ///
    /// Calls are tried in increasing order of candidate-rule count. For each
    /// call, every candidate rule is applied (with reductions), and the step
    /// counts of the resulting branches that still have pending calls are fed
    /// to `tau_function`.
    ///
    /// A call whose tau is below 1.2 is chosen immediately. Otherwise the call
    /// with the lowest `tau + 0.001 * depth` wins; on ties, the earliest call
    /// tried is kept.
    ///
    /// # Panics
    /// Panics if `brch` has no pending calls.
    fn lookahead_choose(&mut self, brch: &Branch) -> usize {
        assert!(!brch.calls.is_empty());
        let mut best_score: f32 = f32::MAX;
        let mut best_idx: usize = 0;

        let mut calls: Vec<usize> = (0..brch.calls.len()).collect();
        calls.sort_by_key(|&call| brch.calls[call].looks.len());

        for call_idx in calls {
            self.stats.step_la();

            let mut steps_per_rule = Vec::new();
            for &rule_idx in brch.calls[call_idx].looks.iter().rev() {
                if let Some((new_brch, steps)) =
                    self.apply_rule_with_reduction(brch, call_idx, rule_idx)
                    && !new_brch.calls.is_empty()
                {
                    steps_per_rule.push(steps);
                }
            }
            let tau = tau_function(&steps_per_rule);
            if tau < 1.2 {
                return call_idx;
            }
            // Small depth penalty breaks near-ties in favour of shallower calls.
            let score = tau + (brch.calls[call_idx].depth as f32) * (0.001_f32);
            if score < best_score {
                best_score = score;
                best_idx = call_idx;
            }
        }
        best_idx
    }

    /// Applies rule `rule_idx` to call `call_idx`, then keeps applying the
    /// unique candidate rule of whichever call `check_reduction` reports.
    /// Reduction stops when no such call remains, when a reported call has no
    /// candidate rules left, or after `MAX_REDUCTION` rule applications in total.
    ///
    /// Returns the resulting branch together with the number of rule
    /// applications performed, or `None` if any application fails.
    ///
    /// # Panics
    /// Panics if a call returned by `check_reduction` has more than one
    /// candidate rule.
    fn apply_rule_with_reduction(
        &mut self,
        brch: &Branch,
        call_idx: usize,
        rule_idx: usize,
    ) -> Option<(Branch, usize)> {
        const MAX_REDUCTION: usize = 10;
        let mut brch = self.apply_rule(brch, call_idx, rule_idx)?;
        // `steps` = number of rule applications already performed on `brch`.
        for steps in 1..MAX_REDUCTION {
            let Some(red_idx) = brch.check_reduction() else {
                return Some((brch, steps));
            };
            let looks = &brch.calls[red_idx].looks;
            assert!(looks.len() <= 1);
            let Some(&next_rule) = looks.first() else {
                // No candidate rule remains for this call; return the branch as-is.
                return Some((brch, steps));
            };
            brch = self.apply_rule(&brch, red_idx, next_rule)?;
        }
        Some((brch, MAX_REDUCTION))
    }

    /// Resolves `brch.calls[call_idx]` with rule `rule_idx` of its predicate.
    ///
    /// The rule is tagged with a fresh context id and its head is unified with
    /// the call's arguments. The rule's primitives are appended and propagated
    /// through the unifier. The rule's body calls are then inserted at
    /// `call_idx` in place of the resolved call. They are iterated in reverse so
    /// that, assuming `Branch::insert` behaves like `Vec::insert`, they end up
    /// in source order.
    ///
    /// Finally, all pending call arguments and answer terms are rewritten with
    /// the unifier. Look-ahead data is refreshed for any call whose arguments
    /// changed.
    ///
    /// Returns `None` if head unification or constraint propagation fails.
    fn apply_rule(&mut self, brch: &Branch, call_idx: usize, rule_idx: usize) -> Option<Branch> {
        let rules = &self.prog.preds[&brch.calls[call_idx].pred].rules;
        self.ctx_cnt += 1;
        let rule_ctx = rules[rule_idx].tag_ctx(self.ctx_cnt);

        let call = &brch.calls[call_idx];
        assert_eq!(rule_ctx.head.len(), call.args.len());

        let mut unifier: Unifier<IdentCtx, LitVal, OptCons<Ident>> = Unifier::new();
        for (par, arg) in rule_ctx.head.iter().zip(call.args.iter()) {
            if unifier.unify(par, arg).is_err() {
                return None;
            }
        }

        let mut new_brch = brch.clone();
        new_brch.depth += 1;
        new_brch.remove(call_idx);

        for (prim, args) in &rule_ctx.prims {
            new_brch.prims.push((*prim, args.clone()));
        }

        if !super::progagate::propagate_unify(&mut new_brch.prims, &mut unifier) {
            return None;
        }

        for (pred, polys, args) in rule_ctx.calls.iter().rev() {
            let mut new_call = PredCall {
                pred: *pred,
                polys: polys.clone(),
                args: args.clone(),
                looks: (0..self.prog.preds[pred].rules.len()).collect(),
                depth: call.depth + 1,
            };

            new_call.lookahead_update(&self.prog.preds[pred].rules);
            new_brch.insert(call_idx, new_call);
        }

        for call in &mut new_brch.calls {
            let mut dirty_flag = false;
            for arg in &mut call.args {
                if let Some(new_arg) = unifier.subst_opt(arg) {
                    *arg = new_arg;
                    dirty_flag = true;
                }
            }
            // update look-ahead information if any information is propagated
            if dirty_flag {
                call.lookahead_update(&self.prog.preds[&call.pred].rules);
            }
        }

        for (_par, val) in &mut new_brch.answers {
            *val = unifier.subst(val);
        }

        Some(new_brch)
    }

    /// Runs iterative-deepening DFS for the query predicate `entry`.
    ///
    /// Depths are explored in consecutive windows of `depth_step`:
    /// `1..=step`, `step+1..=2*step`, and so on. The last window is clamped to
    /// `depth_limit`, so every depth up to the limit is searched even when the
    /// limit is not a multiple of the step. A `depth_step` of 0 is treated as 1.
    ///
    /// Each window restarts the search from scratch and reports only answers
    /// whose depth falls inside it. The loop stops once `answer_limit` answers
    /// have been found. Returns the total number of answers reported.
    pub fn run_iddfs_loop(&mut self, entry: Ident) -> usize {
        let step = self.config.depth_step.max(1);
        let limit = self.config.depth_limit;

        let mut depth_start: usize = 1;
        while depth_start <= limit {
            let depth_end = limit.min(depth_start.saturating_add(step - 1));
            writeln!(
                self.output.stat,
                "[RUN]: try depth = {}... (found answer: {})",
                depth_end, self.ansr_cnt
            )
            .unwrap();

            self.reset();
            self.init_stack(entry);
            self.run_dfs_with_depth(depth_start, depth_end);

            let stat_res = self.stats.print_stat();
            writeln!(self.output.stat, "{stat_res}").unwrap();

            if self.ansr_cnt >= self.config.answer_limit || depth_end >= limit {
                break;
            }
            depth_start = depth_end + 1;
        }
        self.ansr_cnt
    }
}

/// End-to-end smoke test: parses, renames, type-checks, compiles and
/// elaborates a small list program, then runs its single query through the
/// IDDFS loop.
///
/// Asserts only that each front-end pass reports no errors and that the
/// search completes without panicking; the answers themselves are not checked.
#[test]
fn test_runner() {
    const SRC: &str = r#"
datatype IntList where
| Cons(Int, IntList)
| Nil
end

function append(xs: IntList, x: Int) -> IntList
begin
    match xs with
    | Cons(head, tail) => Cons(head, append(tail, x))
    | Nil => Cons(x, Nil)
    end
end

function is_elem(xs: IntList, x: Int) -> Bool
begin
    match xs with
    | Cons(head, tail) => if head == x then true else is_elem(tail, x) 
    | Nil => false
    end
end

function is_elem_after_append(xs: IntList, x: Int) -> Bool
begin
    guard is_elem(append(xs, x), x) = false;
    true
end

query is_elem_after_append(depth_step=5, depth_limit=50, answer_limit=100)
    "#;

    // Front end: surface syntax -> checked AST.
    let (mut ast, parse_errs) = crate::syntax::parser::parse_program(SRC);
    assert!(parse_errs.is_empty());

    let rename_errs = crate::tych::rename::rename_pass(&mut ast);
    assert!(rename_errs.is_empty());

    let check_errs = crate::tych::check::check_pass(&ast);
    assert!(check_errs.is_empty());

    // Back end: AST -> logic program.
    let mut prog = crate::logic::compile::compile_pass(&ast);
    crate::logic::elaborate::elaborate_pass(&mut prog);

    let mut sink = OutputWriter::empty();
    let mut runner = RunnerState::new(
        &prog,
        &mut sink,
        &args::get_test_cli_args(std::path::PathBuf::new()),
    );

    let query = &prog.querys[0];
    query
        .params
        .iter()
        .for_each(|param| runner.config_set_param(param));
    runner.run_iddfs_loop(query.entry);
}