use crate::search_graph::DepthFirstNumber;
use crate::search_graph::SearchGraph;
use crate::solve::{SolveDatabase, SolveIteration};
use crate::stack::{Stack, StackDepth};
use crate::{combine, Guidance, Minimums, Solution, UCanonicalGoal};
use chalk_ir::interner::Interner;
use chalk_ir::Fallible;
use chalk_ir::{Canonical, ConstrainedSubst, Constraints, Goal, InEnvironment, UCanonical};
use chalk_solve::{coinductive_goal::IsCoinductive, RustIrDatabase};
use rustc_hash::FxHashMap;
use std::fmt;
use tracing::debug;
use tracing::{info, instrument};
/// Mutable solving state owned (boxed) by a [`RecursiveSolver`]: the stack of
/// in-progress goals, the search graph of their tables, and a cache of
/// completed solutions.
struct RecursiveContext<I: Interner> {
/// Stack of goals currently being solved; built from the `overflow_depth`
/// given to `RecursiveContext::new`.
stack: Stack,
/// Tables for goals encountered during the current search, indexed by
/// `DepthFirstNumber`; looked up before a goal is pushed onto `stack`.
search_graph: SearchGraph<I>,
/// Completed solutions keyed by canonical goal; consulted first in
/// `solve_goal` and filled via `SearchGraph::move_to_cache`.
cache: FxHashMap<UCanonicalGoal<I>, Fallible<Solution<I>>>,
/// When false, finished SCC heads are rolled back out of the search graph
/// instead of being moved into `cache`.
caching_enabled: bool,
}
/// Short-lived view that pairs a program database with a mutable
/// `RecursiveContext`; created by `RecursiveContext::solver` for one solve.
struct Solver<'me, I: Interner> {
/// The Rust IR database the goals are solved against.
program: &'me dyn RustIrDatabase<I>,
/// Solver state (stack, search graph, cache) borrowed for the duration of the solve.
context: &'me mut RecursiveContext<I>,
}
/// Public recursive solver; implements `chalk_solve::Solver`.
pub struct RecursiveSolver<I: Interner> {
/// Heap-allocated solver state reused across calls to `solve`.
ctx: Box<RecursiveContext<I>>,
}
impl<I: Interner> RecursiveSolver<I> {
pub fn new(overflow_depth: usize, caching_enabled: bool) -> Self {
Self {
ctx: Box::new(RecursiveContext::new(overflow_depth, caching_enabled)),
}
}
}
impl<I: Interner> fmt::Debug for RecursiveSolver<I> {
    /// Prints only the type name; the internal solver state is not shown.
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.write_str("RecursiveSolver")
    }
}
/// Combines two fallible values into one.
trait MergeWith<T> {
    /// Merges `self` with `other`, using `f` when both hold a value.
    fn merge_with<F>(self, other: Self, f: F) -> Self
    where
        F: FnOnce(T, T) -> T;
}

impl<T> MergeWith<T> for Fallible<T> {
    /// Merge rules:
    /// * both `Ok` -> `Ok(f(self, other))`
    /// * exactly one `Ok` -> that `Ok` value
    /// * both `Err` -> the error from `other`
    fn merge_with<F>(self: Fallible<T>, other: Fallible<T>, f: F) -> Fallible<T>
    where
        F: FnOnce(T, T) -> T,
    {
        match self {
            Ok(mine) => match other {
                Ok(theirs) => Ok(f(mine, theirs)),
                Err(_) => Ok(mine),
            },
            // With `self` failed, `other` wins whether it is `Ok` or `Err`.
            Err(_) => other,
        }
    }
}
impl<I: Interner> RecursiveContext<I> {
    /// Builds a fresh context: a new stack (bounded by `overflow_depth`), a
    /// new search graph, and an empty solution cache.
    pub fn new(overflow_depth: usize, caching_enabled: bool) -> Self {
        let stack = Stack::new(overflow_depth);
        let search_graph = SearchGraph::new();
        Self {
            stack,
            search_graph,
            cache: FxHashMap::default(),
            caching_enabled,
        }
    }

    /// Borrows this context together with `program` as a `Solver`.
    pub(crate) fn solver<'me>(
        &'me mut self,
        program: &'me dyn RustIrDatabase<I>,
    ) -> Solver<'me, I> {
        Solver {
            context: self,
            program,
        }
    }
}
impl<'me, I: Interner> Solver<'me, I> {
    /// Solves a top-level goal.
    ///
    /// Panics if the goal stack is not empty, i.e. if called while another
    /// goal is still being solved. The `Minimums` produced for the root goal
    /// are discarded.
    pub(crate) fn solve_root_goal(
        &mut self,
        canonical_goal: &UCanonicalGoal<I>,
    ) -> Fallible<Solution<I>> {
        debug!("solve_root_goal(canonical_goal={:?})", canonical_goal);
        assert!(self.context.stack.is_empty());
        let minimums = &mut Minimums::new();
        self.solve_goal(canonical_goal.clone(), minimums)
    }

    /// Iteratively solves a goal that has just been pushed onto the stack at
    /// `depth` and given the search-graph table `dfn`.
    ///
    /// Each pass runs `solve_iteration`. Then:
    /// * If no subgoal re-entered this goal (the stack entry's cycle flag,
    ///   set by `flag_cycle` in `solve_goal`, is clear), the result is stored
    ///   in the table and we return.
    /// * Otherwise the new result is combined with the stored one via
    ///   `combine::with_priorities_for_goal`. We return if the combined answer
    ///   equals the stored answer (fixed point) or is ambiguous. Otherwise the
    ///   search graph is rolled back to `dfn + 1` and the loop repeats.
    ///
    /// The final answer and priority are left in `search_graph[dfn]`; the
    /// `Minimums` of the last pass are returned.
    #[instrument(level = "debug", skip(self))]
    fn solve_new_subgoal(
        &mut self,
        canonical_goal: UCanonicalGoal<I>,
        depth: StackDepth,
        dfn: DepthFirstNumber,
    ) -> Minimums {
        loop {
            let minimums = &mut Minimums::new();
            let (current_answer, current_prio) = self.solve_iteration(&canonical_goal, minimums);
            debug!(
                "solve_new_subgoal: loop iteration result = {:?} with minimums {:?}",
                current_answer, minimums
            );

            if !self.context.stack[depth].read_and_reset_cycle_flag() {
                // No subgoal depended on this goal: the result is final.
                self.context.search_graph[dfn].solution = current_answer;
                self.context.search_graph[dfn].solution_priority = current_prio;
                return *minimums;
            }

            // Some subgoal used the previously tabled answer; merge it with
            // the freshly computed one.
            let old_answer = &self.context.search_graph[dfn].solution;
            let old_prio = self.context.search_graph[dfn].solution_priority;
            let (current_answer, current_prio) = combine::with_priorities_for_goal(
                self.program.interner(),
                &canonical_goal.canonical.value.goal,
                old_answer.clone(),
                old_prio,
                current_answer,
                current_prio,
            );

            if self.context.search_graph[dfn].solution == current_answer {
                // Fixed point reached: another pass would not change anything.
                return *minimums;
            }

            let current_answer_is_ambig = matches!(&current_answer, Ok(s) if s.is_ambig());

            self.context.search_graph[dfn].solution = current_answer;
            self.context.search_graph[dfn].solution_priority = current_prio;

            // NOTE(review): iteration stops as soon as the answer is
            // ambiguous; presumably this guarantees termination — confirm.
            if current_answer_is_ambig {
                return *minimums;
            }

            // Discard work done after this goal's table and retry with the
            // updated answer.
            self.context.search_graph.rollback_to(dfn + 1);
        }
    }
}
impl<'me, I: Interner> SolveDatabase<I> for Solver<'me, I> {
    /// Solves `goal`, consulting in order: the solution cache, the search
    /// graph (cycle handling), and finally a fresh table plus
    /// `solve_new_subgoal`.
    ///
    /// `minimums` is updated with the links of whatever table the result was
    /// read from.
    #[instrument(level = "info", skip(self, minimums))]
    fn solve_goal(
        &mut self,
        goal: UCanonicalGoal<I>,
        minimums: &mut Minimums,
    ) -> Fallible<Solution<I>> {
        // Completed solution already cached.
        if let Some(value) = self.context.cache.get(&goal) {
            debug!("solve_reduced_goal: cache hit, value={:?}", value);
            return value.clone();
        }
        if let Some(dfn) = self.context.search_graph.lookup(&goal) {
            // The goal already has a table.
            if let Some(depth) = self.context.search_graph[dfn].stack_depth {
                // ...and is still on the stack: this is a cycle.
                if self.context.stack.coinductive_cycle_from(depth) {
                    // Coinductive cycle: succeed with the trivial substitution
                    // and no constraints. This result is returned directly
                    // and not written to the table.
                    let value = ConstrainedSubst {
                        subst: goal.trivial_substitution(self.program.interner()),
                        constraints: Constraints::empty(self.program.interner()),
                    };
                    debug!("applying coinductive semantics");
                    return Ok(Solution::Unique(Canonical {
                        value,
                        binders: goal.canonical.binders,
                    }));
                }
                // Inductive cycle: tell the stack entry so that
                // `solve_new_subgoal` iterates to a fixed point.
                self.context.stack[depth].flag_cycle();
            }
            minimums.update_from(self.context.search_graph[dfn].links);
            // Return the current (possibly provisional) answer in the table.
            let previous_solution = self.context.search_graph[dfn].solution.clone();
            let previous_solution_priority = self.context.search_graph[dfn].solution_priority;
            info!(
                "solve_goal: cycle detected, previous solution {:?} with prio {:?}",
                previous_solution, previous_solution_priority
            );
            previous_solution
        } else {
            // New goal: push it, create a table, and solve it.
            let coinductive_goal = goal.is_coinductive(self.program);
            let depth = self.context.stack.push(coinductive_goal);
            let dfn = self.context.search_graph.insert(&goal, depth);
            let subgoal_minimums = self.solve_new_subgoal(goal, depth, dfn);
            self.context.search_graph[dfn].links = subgoal_minimums;
            self.context.search_graph[dfn].stack_depth = None;
            self.context.stack.pop(depth);
            minimums.update_from(subgoal_minimums);
            let result = self.context.search_graph[dfn].solution.clone();
            let priority = self.context.search_graph[dfn].solution_priority;
            // The subtree depended on nothing older than this goal, so its
            // tables are final: cache them, or roll them back if caching
            // is disabled.
            if subgoal_minimums.positive >= dfn {
                if self.context.caching_enabled {
                    self.context
                        .search_graph
                        .move_to_cache(dfn, &mut self.context.cache);
                    debug!("solve_reduced_goal: SCC head encountered, moving to cache");
                } else {
                    debug!(
                        "solve_reduced_goal: SCC head encountered, rolling back as caching disabled"
                    );
                    self.context.search_graph.rollback_to(dfn);
                }
            }
            info!("solve_goal: solution = {:?} prio {:?}", result, priority);
            result
        }
    }

    /// The interner of the underlying program database.
    fn interner(&self) -> &I {
        // `interner()` already yields `&I`; borrowing it again only relied on
        // deref coercion (clippy::needless_borrow).
        self.program.interner()
    }

    /// The program database goals are solved against.
    fn db(&self) -> &dyn RustIrDatabase<I> {
        self.program
    }
}
impl<I: Interner> chalk_solve::Solver<I> for RecursiveSolver<I> {
fn solve(
&mut self,
program: &dyn RustIrDatabase<I>,
goal: &UCanonical<InEnvironment<Goal<I>>>,
) -> Option<chalk_solve::Solution<I>> {
self.ctx
.solver(program)
.solve_root_goal(goal)
.ok()
.map(|s| match s {
Solution::Unique(c) => chalk_solve::Solution::Unique(c),
Solution::Ambig(g) => chalk_solve::Solution::Ambig(match g {
Guidance::Definite(g) => chalk_solve::Guidance::Definite(g),
Guidance::Suggested(g) => chalk_solve::Guidance::Suggested(g),
Guidance::Unknown => chalk_solve::Guidance::Unknown,
}),
})
}
fn solve_limited(
&mut self,
program: &dyn RustIrDatabase<I>,
goal: &UCanonical<InEnvironment<Goal<I>>>,
_should_continue: impl std::ops::Fn() -> bool,
) -> Option<chalk_solve::Solution<I>> {
self.ctx
.solver(program)
.solve_root_goal(goal)
.ok()
.map(|s| match s {
Solution::Unique(c) => chalk_solve::Solution::Unique(c),
Solution::Ambig(g) => chalk_solve::Solution::Ambig(match g {
Guidance::Definite(g) => chalk_solve::Guidance::Definite(g),
Guidance::Suggested(g) => chalk_solve::Guidance::Suggested(g),
Guidance::Unknown => chalk_solve::Guidance::Unknown,
}),
})
}
fn solve_multiple(
&mut self,
_program: &dyn RustIrDatabase<I>,
_goal: &UCanonical<InEnvironment<Goal<I>>>,
_f: impl FnMut(chalk_solve::SubstitutionResult<Canonical<ConstrainedSubst<I>>>, bool) -> bool,
) -> bool {
unimplemented!("Recursive solver doesn't support multiple answers")
}
}