use std::marker::PhantomData;
use num_traits::{AsPrimitive, PrimInt, Unsigned};
use thiserror::Error;
use crate::{
algorithm::{algorithm::Algorithm, direction::Direction, r#move::r#move::Move},
puzzle::{
label::label::RowGrids, sliding_puzzle::SlidingPuzzle, solvable::Solvable,
solved_state::SolvedState,
},
solver::{
heuristic::{manhattan::ManhattanDistance, Heuristic},
statistics::SolverIterationStats,
},
};
#[derive(Clone, Debug, PartialEq, Eq)]
struct Stack {
stack: [Direction; 256],
idx: usize,
}
impl Stack {
    /// Appends `direction` to the path.
    ///
    /// Panics (index out of bounds) if the stack already holds 256 moves.
    fn push(&mut self, direction: Direction) {
        let slot = self.idx;
        self.stack[slot] = direction;
        self.idx = slot + 1;
    }

    /// Returns the most recently pushed move, or `None` when the stack is empty.
    fn top(&self) -> Option<Direction> {
        self.idx.checked_sub(1).map(|last| self.stack[last])
    }

    /// Removes and returns the most recently pushed move.
    ///
    /// Must only be called on a non-empty stack.
    fn pop(&mut self) -> Direction {
        let last = self.idx - 1;
        self.idx = last;
        self.stack[last]
    }

    /// Discards every move. Old entries stay in the array but are no longer live.
    fn clear(&mut self) {
        self.idx = 0;
    }
}
impl Default for Stack {
    /// Creates an empty stack.
    ///
    /// `Direction::Up` only fills the unused slots; it is never read, since
    /// every reader looks at `stack[..idx]` only.
    fn default() -> Self {
        let stack = [Direction::Up; 256];
        Self { idx: 0, stack }
    }
}
impl From<&Stack> for Algorithm {
    /// Builds an algorithm from the live moves of `stack`, in push order
    /// (bottom of the stack first).
    fn from(stack: &Stack) -> Self {
        let live = &stack.stack[..stack.idx];
        let moves = live.iter().copied().map(Move::from).collect();
        Self::with_moves(moves)
    }
}
/// Errors that can be returned by [`Solver`].
#[derive(Clone, Debug, Error, PartialEq, Eq)]
pub enum SolverError {
    /// The solver stopped increasing the search depth without finding a solution.
    #[error("NoSolutionFound: no solution was found within the range searched")]
    NoSolutionFound,
    /// The puzzle's size is not supported by the solver.
    ///
    /// Not returned by any `Solver` method in this module.
    #[error("IncompatiblePuzzleSize: the puzzle size is incompatible with the solver")]
    IncompatiblePuzzleSize,
    /// The solver's solved-state definition reports the puzzle as unsolvable.
    #[error("Unsolvable: the puzzle is unsolvable")]
    Unsolvable,
}
/// Iterative-deepening depth-first solver for sliding puzzles, pruned by a heuristic.
///
/// Type parameters:
/// - `Puzzle`: the puzzle type being solved (tracked only through `PhantomData`).
/// - `T`: the integer type used for search depths and heuristic bounds.
/// - `S`: the solved-state definition, also used for the solvability check.
/// - `H`: the heuristic used to prune the search.
#[derive(Clone, Debug)]
pub struct Solver<'a, Puzzle, T, S, H> {
    // Moves along the current search path; holds the solution once one is found.
    stack: Stack,
    phantom_puzzle: PhantomData<Puzzle>,
    // Estimate of the moves still needed; a branch is pruned when it exceeds the
    // remaining depth.
    heuristic: &'a H,
    // Decides when a position is solved and whether a position is solvable.
    solved_state: &'a S,
    phantom_t: PhantomData<T>,
}
impl<Puzzle: SlidingPuzzle + Clone> Default
    for Solver<'static, Puzzle, u8, RowGrids, ManhattanDistance<'static, RowGrids>>
{
    /// Creates a solver that uses `u8` depths, the `RowGrids` solved state, and
    /// a Manhattan-distance heuristic over `RowGrids`.
    fn default() -> Self {
        // Both values are built from constants, so borrowing them yields `'static`
        // references.
        let solved_state: &'static RowGrids = &RowGrids;
        let heuristic: &'static ManhattanDistance<'static, RowGrids> =
            &ManhattanDistance(&RowGrids);
        Self::new(heuristic, solved_state)
    }
}
impl<'a, Puzzle, S, H> Solver<'a, Puzzle, u8, S, H> {
    /// Creates a solver that uses `u8` for search depths and heuristic bounds.
    ///
    /// Use `new_with_t` to choose a different depth type.
    pub fn new(heuristic: &'a H, solved_state: &'a S) -> Self {
        let stack = Stack::default();
        Self {
            heuristic,
            solved_state,
            stack,
            phantom_puzzle: PhantomData,
            phantom_t: PhantomData,
        }
    }
}
impl<'a, Puzzle, T, S, H> Solver<'a, Puzzle, T, S, H>
where
    Puzzle: SlidingPuzzle + Clone,
    T: PrimInt + Unsigned + 'static,
    S: SolvedState + Solvable,
    H: Heuristic<Puzzle, T>,
    u8: AsPrimitive<T>,
{
    /// Creates a solver that uses `T` for search depths and heuristic bounds.
    ///
    /// `heuristic` prunes the search. `solved_state` decides both when a position
    /// is solved and whether a position is solvable at all.
    pub fn new_with_t(heuristic: &'a H, solved_state: &'a S) -> Self {
        Self {
            stack: Stack::default(),
            phantom_puzzle: PhantomData,
            heuristic,
            solved_state,
            phantom_t: PhantomData::<T>,
        }
    }

    /// Depth-limited search for a sequence of exactly `depth` moves that solves `puzzle`.
    ///
    /// On success, returns `true`, leaves the moves on `self.stack`, and leaves
    /// `puzzle` in the solved position.
    ///
    /// On failure, returns `false`. Every move it made is undone, so `puzzle` and
    /// `self.stack` are back to their state on entry.
    fn dfs(&mut self, puzzle: &mut Puzzle, depth: T) -> bool {
        if depth == T::zero() {
            return self.solved_state.is_solved(puzzle);
        }

        // Prune: the heuristic says more moves are needed than remain.
        if self.heuristic.bound(puzzle) > depth {
            return false;
        }

        for d in [
            Direction::Up,
            Direction::Down,
            Direction::Left,
            Direction::Right,
        ] {
            // Never immediately undo the previous move.
            if self.stack.top() == Some(d.inverse()) {
                continue;
            }
            if !puzzle.try_move_dir(d) {
                continue;
            }

            self.stack.push(d);
            if self.dfs(puzzle, depth - T::one()) {
                return true;
            }

            // Backtrack.
            self.stack.pop();
            puzzle.try_move_dir(d.inverse());
        }

        false
    }

    /// Shared implementation of [`Self::solve`] and [`Self::solve_with_callback`].
    ///
    /// Runs iterative deepening on a clone of `puzzle`. It starts at the heuristic's
    /// bound and calls [`Self::dfs`] with increasing depth limits until a solution is
    /// found. After each unsuccessful depth, `iteration_callback` (if any) receives
    /// that depth.
    ///
    /// # Errors
    ///
    /// - [`SolverError::Unsolvable`] if `solved_state` reports `puzzle` as unsolvable.
    /// - [`SolverError::NoSolutionFound`] if the next depth limit would overflow `T`,
    ///   or would exceed `u8::MAX`. That is the largest depth [`SolverIterationStats`]
    ///   can report; the move stack also only holds 256 moves.
    fn solve_impl(
        &mut self,
        puzzle: &Puzzle,
        iteration_callback: Option<&dyn Fn(SolverIterationStats)>,
    ) -> Result<Algorithm, SolverError> {
        if !self.solved_state.is_solvable(puzzle) {
            return Err(SolverError::Unsolvable);
        }

        self.stack.clear();
        let mut puzzle = puzzle.clone();
        let mut depth = self.heuristic.bound(&puzzle);
        let step: T = 2u8.as_();

        loop {
            // With a `T` wider than `u8`, the depth could exceed what the iteration
            // stats can report and what the 256-move stack can hold. That previously
            // panicked (`to_u8().unwrap()` or an out-of-bounds push); now it ends the
            // search instead. For `T = u8` this never triggers.
            let reported_depth = match depth.to_u8() {
                Some(d) => d,
                None => return Err(SolverError::NoSolutionFound),
            };

            if self.dfs(&mut puzzle, depth) {
                let mut solution: Algorithm = (&self.stack).into();
                solution.simplify();
                return Ok(solution);
            }

            if let Some(f) = iteration_callback {
                f(SolverIterationStats {
                    depth: reported_depth,
                });
            }

            // NOTE(review): the limit grows by 2 on the assumption that all solutions
            // of a given position share the same length parity. `dfs` only accepts
            // paths of exactly `depth` moves, so odd-offset depths are skipped.
            depth = match depth.checked_add(&step) {
                Some(d) => d,
                None => return Err(SolverError::NoSolutionFound),
            };
        }
    }

    /// Finds a solution for `puzzle`, which is left unmodified.
    ///
    /// The returned algorithm has already had `simplify` applied.
    ///
    /// # Errors
    ///
    /// - [`SolverError::Unsolvable`] if the puzzle cannot be solved.
    /// - [`SolverError::NoSolutionFound`] if the supported depth range is exhausted.
    pub fn solve(&mut self, puzzle: &Puzzle) -> Result<Algorithm, SolverError> {
        self.solve_impl(puzzle, None)
    }

    /// Like [`Self::solve`], but calls `callback` after every search depth that
    /// did not find a solution.
    ///
    /// # Errors
    ///
    /// Same as [`Self::solve`].
    pub fn solve_with_callback(
        &mut self,
        puzzle: &Puzzle,
        callback: &dyn Fn(SolverIterationStats),
    ) -> Result<Algorithm, SolverError> {
        self.solve_impl(puzzle, Some(callback))
    }
}
#[cfg(test)]
mod tests {
    use std::str::FromStr as _;

    use super::*;
    use crate::puzzle::{label::label::Rows, puzzle::Puzzle};

    /// Scrambled position shared by both tests.
    const SCRAMBLE: &str = "8 6 7/2 5 4/3 0 1";

    #[test]
    fn test_row_grids_manhattan() {
        let puzzle = Puzzle::from_str(SCRAMBLE).unwrap();
        let heuristic = ManhattanDistance(&RowGrids);
        let mut solver = Solver::new(&heuristic, &RowGrids);
        let moves = solver.solve(&puzzle).unwrap().len_stm::<u64>();
        assert_eq!(moves, 31);
    }

    #[test]
    fn test_rows_manhattan() {
        let puzzle = Puzzle::from_str(SCRAMBLE).unwrap();
        let heuristic = ManhattanDistance(&Rows);
        let mut solver = Solver::new(&heuristic, &Rows);
        let moves = solver.solve(&puzzle).unwrap().len_stm::<u64>();
        assert_eq!(moves, 23);
    }
}