amazeing 0.8.1

Amazeing is a maze generator/solver application with simulation/visualization.
use crate::maze::{BLOCK, Maze, OPEN, UnitShape};
use crate::util::IsDivisible;
use std::cmp::Ordering;
use std::ops::{Add, Sub};

/// Grid dimensions from which bounds-checked [`Node`] values are created.
#[derive(Debug, Copy, Clone)]
pub struct NodeFactory {
    /// Number of rows in the grid; `at` only accepts row indices below this value.
    pub rows: usize,
    /// Number of columns in the grid; `at` only accepts column indices below this value.
    pub cols: usize,
}

impl NodeFactory {
    pub fn new(rows: usize, cols: usize) -> Self {
        Self { rows, cols }
    }

    pub fn at(&self, row: usize, col: usize) -> Option<Node> {
        if row >= self.rows || col >= self.cols {
            None
        } else {
            Some(Node {
                row,
                col,
                rows: self.rows,
                cols: self.cols,
            })
        }
    }
}

/// A cell position inside a grid of known size.
///
/// `row` and `col` are the cell coordinates. The private `rows` and `cols` fields carry the
/// grid dimensions so that arithmetic on a node can reject results outside the grid.
/// The derived comparisons look at the fields in declaration order
/// (`row`, `col`, `rows`, `cols`).
#[derive(Default, Debug, Copy, Clone, PartialOrd, PartialEq, Eq, Hash, Ord)]
pub struct Node {
    /// Zero-based row index of the cell.
    pub row: usize,
    /// Zero-based column index of the cell.
    pub col: usize,
    // Grid dimensions the node was created for; used as exclusive upper bounds.
    rows: usize,
    cols: usize,
}

impl Add<(usize, usize)> for Node {
    type Output = Option<Self>;

    /// Moves the node by `rhs = (row_delta, col_delta)` towards higher indices.
    ///
    /// Returns `None` when the target lies outside the grid. Offsets so large that the
    /// addition would overflow `usize` are also reported as `None` instead of panicking
    /// (debug builds) or wrapping around to a bogus in-grid cell (release builds).
    fn add(self, rhs: (usize, usize)) -> Self::Output {
        let row = self.row.checked_add(rhs.0)?;
        let col = self.col.checked_add(rhs.1)?;

        if row >= self.rows || col >= self.cols {
            None
        } else {
            Some(Self { row, col, ..self })
        }
    }
}

impl Sub<(usize, usize)> for Node {
    type Output = Option<Self>;

    /// Moves the node by `rhs = (row_delta, col_delta)` towards lower indices.
    ///
    /// Returns `None` when either coordinate would drop below zero. The upper grid bound
    /// is not re-checked because subtraction can only decrease the coordinates.
    fn sub(self, rhs: (usize, usize)) -> Self::Output {
        let row = self.row.checked_sub(rhs.0)?;
        let col = self.col.checked_sub(rhs.1)?;

        Some(Self { row, col, ..self })
    }
}

impl Node {
    pub fn left(self, steps: usize) -> Box<dyn Fn(Node) -> Option<Node>> {
        Box::new(move |n| n - (0, steps))
    }

    pub fn right(self, steps: usize) -> Box<dyn Fn(Node) -> Option<Node>> {
        Box::new(move |n| n + (0, steps))
    }

    pub fn up(self, steps: usize) -> Box<dyn Fn(Node) -> Option<Node>> {
        Box::new(move |n| n - (steps, 0))
    }

    pub fn down(self, steps: usize) -> Box<dyn Fn(Node) -> Option<Node>> {
        Box::new(move |n| n + (steps, 0))
    }

    pub fn left_up(self, steps: usize) -> Box<dyn Fn(Node) -> Option<Node>> {
        Box::new(move |n| n - (steps, steps))
    }

    pub fn left_down(self, steps: usize) -> Box<dyn Fn(Node) -> Option<Node>> {
        Box::new(move |n| if let Some(data) = n + (steps, 0) { data - (0, steps) } else { None })
    }

    pub fn right_up(self, steps: usize) -> Box<dyn Fn(Node) -> Option<Node>> {
        Box::new(move |n| if let Some(data) = n - (steps, 0) { data + (0, steps) } else { None })
    }

    pub fn right_down(self, steps: usize) -> Box<dyn Fn(Node) -> Option<Node>> {
        Box::new(move |n| n + (steps, steps))
    }

    pub fn neighbours(self, unit_shape: &UnitShape) -> Vec<Node> {
        match unit_shape {
            UnitShape::Triangle => match self.row % 4 {
                0 => vec![self.down(1), self.left_down(1), self.up(1)],
                1 => vec![self.right_up(1), self.down(1), self.up(1)],
                2 => vec![self.right_down(1), self.down(1), self.up(1)],
                3 => vec![self.up(1), self.down(1), self.left_up(1)],
                _ => unreachable!(),
            },
            UnitShape::Square | UnitShape::Octagon => vec![self.right(1), self.down(1), self.left(1), self.up(1)],
            UnitShape::Rhombus => {
                if self.row.is_even() {
                    vec![self.down(1), self.left_down(1), self.left_up(1), self.up(1)]
                } else {
                    vec![self.right_down(1), self.down(1), self.up(1), self.right_up(1)]
                }
            }
            UnitShape::Hexagon => {
                if self.row.is_even() {
                    vec![self.right(1), self.down(1), self.left_down(1), self.left(1), self.left_up(1), self.up(1)]
                } else {
                    vec![self.right(1), self.right_down(1), self.down(1), self.left(1), self.up(1), self.right_up(1)]
                }
            }
            UnitShape::OctagonSquare => {
                if self.row.is_even() {
                    vec![
                        self.right(1),
                        self.down(1),
                        self.down(2),
                        self.left_down(1),
                        self.left(1),
                        self.left_up(1),
                        self.up(2),
                        self.up(1),
                    ]
                } else {
                    vec![self.right_down(1), self.down(1), self.up(1), self.right_up(1)]
                }
            }
            UnitShape::HexagonRectangle => {
                if self.row.is_even() {
                    vec![self.right(1), self.down(1), self.left_down(1), self.left(1), self.left_up(1), self.up(1)]
                } else {
                    vec![self.right_down(1), self.down(1), self.up(1), self.right_up(1)]
                }
            }
        }
        .iter()
        .filter_map(|i| i(self))
        .collect()
    }

    pub fn neighbours_open(self, maze: &Maze, unit_shape: &UnitShape) -> Vec<Node> {
        self.neighbours(unit_shape).into_iter().filter(|p| maze[*p] == OPEN).collect()
    }

    pub fn neighbours_block(self, maze: &Maze, unit_shape: &UnitShape) -> Vec<Node> {
        self.neighbours(unit_shape).into_iter().filter(|p| maze[*p] == BLOCK).collect()
    }
}

/// Common interface of weighted node entries.
///
/// Implementors must be totally ordered (`Ord`), which lets them be stored in ordered
/// containers such as `std::collections::BinaryHeap`.
pub trait DNodeWeighted: Ord {
    /// Builds an entry from a node and its two cost values.
    // NOTE(review): `heu_cost` presumably means "heuristic cost" — confirm against callers.
    fn new(node: Node, cost: u32, heu_cost: u32) -> Self;
    /// Returns the node stored in this entry.
    fn node(&self) -> Node;
    /// Returns the `cost` value stored in this entry.
    fn cost(&self) -> u32;
    /// Returns the `heu_cost` value stored in this entry.
    fn heu_cost(&self) -> u32;
}

/// Weighted node entry whose ordering ranks a lower `heu_cost` as greater.
///
/// Equality is derived and therefore compares all three fields.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct DNodeWeightedForward {
    /// The grid node this entry refers to.
    pub(crate) node: Node,
    /// Cost value associated with the node.
    pub(crate) cost: u32,
    /// Value the ordering of entries is based on.
    pub(crate) heu_cost: u32,
}

impl PartialOrd<Self> for DNodeWeightedForward {
    /// Delegates to [`Ord::cmp`] so the partial and the total order always agree.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for DNodeWeightedForward {
    /// Compares entries so that a *lower* `heu_cost` ranks as *greater*.
    ///
    /// With a max-heap such as `std::collections::BinaryHeap` this makes the entry with
    /// the smallest `heu_cost` pop first.
    ///
    /// Ties on `heu_cost` are broken by `cost` and then by `node` (both in the same
    /// reversed sense). The tie-breakers are needed so that `cmp` returns `Equal` only for
    /// entries that are also equal under the derived `PartialEq`, as the `Ord` contract
    /// requires.
    fn cmp(&self, other: &Self) -> Ordering {
        other
            .heu_cost
            .cmp(&self.heu_cost)
            .then_with(|| other.cost.cmp(&self.cost))
            .then_with(|| other.node.cmp(&self.node))
    }
}

impl DNodeWeighted for DNodeWeightedForward {
    /// Stores the three values unchanged in a new entry.
    fn new(node: Node, cost: u32, heu_cost: u32) -> Self {
        Self { node, cost, heu_cost }
    }

    /// Returns a copy of the stored node.
    fn node(&self) -> Node {
        self.node
    }

    /// Returns the stored `cost`.
    fn cost(&self) -> u32 {
        self.cost
    }

    /// Returns the stored `heu_cost`.
    fn heu_cost(&self) -> u32 {
        self.heu_cost
    }
}

/// Weighted node entry whose ordering ranks a higher `heu_cost` as greater.
///
/// Equality is derived and therefore compares all three fields.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct DNodeWeightedBackward {
    /// The grid node this entry refers to.
    pub(crate) node: Node,
    /// Cost value associated with the node.
    pub(crate) cost: u32,
    /// Value the ordering of entries is based on.
    pub(crate) heu_cost: u32,
}

impl PartialOrd<Self> for DNodeWeightedBackward {
    /// Delegates to [`Ord::cmp`] so the partial and the total order always agree.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for DNodeWeightedBackward {
    /// Compares entries so that a *higher* `heu_cost` ranks as *greater*.
    ///
    /// With a max-heap such as `std::collections::BinaryHeap` this makes the entry with
    /// the largest `heu_cost` pop first.
    ///
    /// Ties on `heu_cost` are broken by `cost` and then by `node`. The tie-breakers are
    /// needed so that `cmp` returns `Equal` only for entries that are also equal under the
    /// derived `PartialEq`, as the `Ord` contract requires.
    fn cmp(&self, other: &Self) -> Ordering {
        self.heu_cost
            .cmp(&other.heu_cost)
            .then_with(|| self.cost.cmp(&other.cost))
            .then_with(|| self.node.cmp(&other.node))
    }
}

impl DNodeWeighted for DNodeWeightedBackward {
    /// Stores the three values unchanged in a new entry.
    fn new(node: Node, cost: u32, heu_cost: u32) -> Self {
        Self { node, cost, heu_cost }
    }

    /// Returns a copy of the stored node.
    fn node(&self) -> Node {
        self.node
    }

    /// Returns the stored `cost`.
    fn cost(&self) -> u32 {
        self.cost
    }

    /// Returns the stored `heu_cost`.
    fn heu_cost(&self) -> u32 {
        self.heu_cost
    }
}

/// Direction in which weighted nodes are prioritised.
// NOTE(review): presumably selects between `DNodeWeightedForward` and
// `DNodeWeightedBackward`; no usage is visible in this file — confirm against callers.
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum WeightDirection {
    /// Forward weighting.
    Forward,
    /// Backward weighting.
    Backward,
}

// Unit tests for node construction, node arithmetic, neighbour lookup and heap ordering.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::maze::{BLOCK, Maze, UnitShape};
    use std::collections::BinaryHeap;

    // `NodeFactory::at` accepts the last valid cell and rejects the first index past each edge.
    #[test]
    fn node_factory_bounds_are_enforced() {
        let f = NodeFactory::new(2, 3);
        assert!(f.at(1, 2).is_some());
        assert!(f.at(2, 0).is_none());
        assert!(f.at(0, 3).is_none());
    }

    // `+` and `-` on a node return `None` as soon as the result would leave the 3x3 grid.
    #[test]
    fn node_add_and_sub_respect_grid_bounds() {
        let f = NodeFactory::new(3, 3);
        let n = f.at(1, 1).unwrap();
        assert_eq!((n + (1, 1)).unwrap(), f.at(2, 2).unwrap());
        assert!((n + (2, 0)).is_none());
        assert_eq!((n - (1, 1)).unwrap(), f.at(0, 0).unwrap());
        assert!((n - (2, 0)).is_none());
    }

    // Interior cells of a 5x5 grid: every candidate offset is in bounds, so the neighbour
    // count equals the number of offsets listed for the shape and row parity.
    #[test]
    fn neighbours_cover_shape_specific_rules() {
        let f = NodeFactory::new(5, 5);
        let center_even = f.at(2, 2).unwrap();
        let center_odd = f.at(3, 2).unwrap();

        assert_eq!(center_even.neighbours(&UnitShape::Square).len(), 4);
        assert_eq!(center_even.neighbours(&UnitShape::Octagon).len(), 4);
        assert_eq!(center_even.neighbours(&UnitShape::Rhombus).len(), 4);
        assert_eq!(center_odd.neighbours(&UnitShape::Rhombus).len(), 4);
        assert_eq!(center_even.neighbours(&UnitShape::Hexagon).len(), 6);
        assert_eq!(center_odd.neighbours(&UnitShape::Hexagon).len(), 6);
        assert_eq!(center_even.neighbours(&UnitShape::HexagonRectangle).len(), 6);
        assert_eq!(center_odd.neighbours(&UnitShape::HexagonRectangle).len(), 4);
        assert_eq!(center_even.neighbours(&UnitShape::OctagonSquare).len(), 8);
        assert_eq!(center_odd.neighbours(&UnitShape::OctagonSquare).len(), 4);

        // The exact neighbour order is part of the checked behaviour for Rhombus.
        assert_eq!(
            center_even.neighbours(&UnitShape::Rhombus),
            vec![f.at(3, 2).unwrap(), f.at(3, 1).unwrap(), f.at(1, 1).unwrap(), f.at(1, 2).unwrap(),]
        );
        assert_eq!(
            center_odd.neighbours(&UnitShape::Rhombus),
            vec![f.at(4, 3).unwrap(), f.at(4, 2).unwrap(), f.at(2, 2).unwrap(), f.at(2, 3).unwrap(),]
        );
    }

    // Two of the four square neighbours of the centre are opened; the other two stay blocked.
    #[test]
    fn neighbours_open_and_block_filter_cells() {
        let mut maze = Maze::new(UnitShape::Square, 3, 3, BLOCK);
        let f = NodeFactory::new(3, 3);
        let c = f.at(1, 1).unwrap();
        maze[f.at(1, 2).unwrap()] = OPEN;
        maze[f.at(0, 1).unwrap()] = OPEN;

        let open = c.neighbours_open(&maze, &UnitShape::Square);
        let block = c.neighbours_block(&maze, &UnitShape::Square);

        assert_eq!(open.len(), 2);
        assert_eq!(block.len(), 2);
    }

    // In a `BinaryHeap` the forward entry pops the lowest `heu_cost` first and the
    // backward entry pops the highest `heu_cost` first.
    #[test]
    fn weighted_node_ordering_works_for_heap() {
        let f = NodeFactory::new(2, 2);
        let a = f.at(0, 0).unwrap();
        let b = f.at(0, 1).unwrap();

        let mut forward = BinaryHeap::new();
        forward.push(DNodeWeightedForward::new(a, 0, 1));
        forward.push(DNodeWeightedForward::new(b, 0, 5));
        assert_eq!(forward.pop().unwrap().heu_cost(), 1);

        let mut backward = BinaryHeap::new();
        backward.push(DNodeWeightedBackward::new(a, 0, 1));
        backward.push(DNodeWeightedBackward::new(b, 0, 5));
        assert_eq!(backward.pop().unwrap().heu_cost(), 5);
    }
}