use crate::csp::Constraint;
use crate::fill::Fill;
use crate::grid::Grid;
use crate::mdd::Mdd;
use crate::memo::Memo;
use crate::operator::{CommutativeOperator, NonCommutativeOperator};
use crate::polyomino::{Cell, Polyomino};
use crate::table::Table;
use crate::{Error, Error::EmptyFills, N, T};
use std::fmt::{Display, Formatter};
/// Arithmetic operator of a cage, as passed to [`Cage::new`].
///
/// `Display` renders each operator as a symbol (`+`, `−`, `×`, `÷`, `=`).
/// The variant order is significant: it fixes the derived `PartialOrd`/`Ord`
/// ordering and may also matter to serde formats that encode variants by
/// index, so do not reorder variants casually.
#[derive(
    Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, serde::Serialize, serde::Deserialize,
)]
pub enum CageOperator {
    /// The cells' values sum to the target.
    Add,
    /// Two-cell cage whose values differ by the target.
    Subtract,
    /// The cells' values multiply to the target.
    Multiply,
    /// Two-cell cage whose values divide to give the target.
    Divide,
    /// Single-cell cage whose value is the target itself.
    Given,
}
/// Precomputed admissible value assignments for a cage, chosen by operator
/// family.
#[derive(Clone, PartialEq, Eq, Debug)]
enum CageSupport {
    /// Add/Multiply: operator, target, and an [`Mdd`] built with the cage's
    /// row/column groups from [`collinear_groups`].
    Commutative(CommutativeOperator, T, Mdd),
    /// Subtract/Divide: operator, target, and a [`Table`] of admissible tuples.
    NonCommutative(NonCommutativeOperator, T, Table),
    /// A single cell fixed to this value.
    Given(N),
}
/// A cage: a group of cells plus the operator/target constraint their values
/// must satisfy, stored as a precomputed support.
#[derive(Debug, Clone)]
pub struct Cage {
    /// Cells covered by the cage.
    pub polyomino: Polyomino,
    // Private so a support can only be created through the `Cage` constructors.
    support: CageSupport,
}
impl Cage {
pub fn new(
n: N,
polyomino: Polyomino,
operation: CageOperator,
target: T,
) -> Result<Self, Error> {
let k = polyomino.len();
let result = match operation {
CageOperator::Add => {
Self::commutative(n, polyomino.clone(), CommutativeOperator::Add, target)
}
CageOperator::Multiply => {
Self::commutative(n, polyomino.clone(), CommutativeOperator::Multiply, target)
}
CageOperator::Subtract => {
if k != 2 {
return Err(Error::InfeasibleCage(polyomino, u64::from(target)));
}
Self::non_commutative(
n,
polyomino.clone(),
NonCommutativeOperator::Subtract,
target,
)
}
CageOperator::Divide => {
if k != 2 || target < 2 {
return Err(Error::InfeasibleCage(polyomino, u64::from(target)));
}
Self::non_commutative(n, polyomino.clone(), NonCommutativeOperator::Divide, target)
}
CageOperator::Given => {
if k != 1 {
return Err(Error::InfeasibleCage(polyomino, u64::from(target)));
}
let Some(&cell) = polyomino.iter().next() else {
return Err(Error::InfeasibleCage(polyomino, u64::from(target)));
};
let value = N::try_from(target)
.map_err(|_| Error::InfeasibleCage(polyomino, u64::from(target)))?;
return Self::given(cell, value);
}
};
result.map_err(|e| match e {
EmptyFills => Error::InfeasibleCage(polyomino, u64::from(target)),
other => other,
})
}
pub fn commutative(
n: N,
polyomino: Polyomino,
operation: CommutativeOperator,
target: T,
) -> Result<Self, Error> {
let k = N::try_from(polyomino.len()).map_err(|_| EmptyFills)?;
let lines = collinear_groups(&polyomino);
let mdd = Mdd::new(n, k, operation, target, &lines)?;
let support = CageSupport::Commutative(operation, target, mdd);
Ok(Self { polyomino, support })
}
pub fn non_commutative(
n: N,
polyomino: Polyomino,
operation: NonCommutativeOperator,
target: T,
) -> Result<Self, Error> {
let table = Table::non_commutative(n, operation, target)?;
let support = CageSupport::NonCommutative(operation, target, table);
Ok(Self { polyomino, support })
}
pub fn given(cell: Cell, n: N) -> Result<Self, Error> {
Ok(Self {
polyomino: Polyomino::from(vec![cell])?,
support: CageSupport::Given(n),
})
}
pub fn get(&self, cell: Cell) -> Result<Fill, Error> {
let index = self.polyomino_index(cell)?;
let fill = match &self.support {
CageSupport::Commutative(_, _, memo) => memo.get(index)?,
CageSupport::NonCommutative(_, _, memo) => memo.get(index)?,
CageSupport::Given(n) => Fill::from(&[*n]),
};
Ok(fill)
}
#[must_use]
pub fn op_target(&self) -> (CageOperator, T) {
match &self.support {
CageSupport::Commutative(op, target, _) => (
match op {
CommutativeOperator::Add => CageOperator::Add,
CommutativeOperator::Multiply => CageOperator::Multiply,
},
*target,
),
CageSupport::NonCommutative(op, target, _) => (
match op {
NonCommutativeOperator::Subtract => CageOperator::Subtract,
NonCommutativeOperator::Divide => CageOperator::Divide,
},
*target,
),
CageSupport::Given(n) => (CageOperator::Given, T::from(*n)),
}
}
fn polyomino_index(&self, cell: Cell) -> Result<usize, Error> {
self.polyomino
.iter()
.position(|c| *c == cell)
.ok_or(Error::MissingCell(cell))
}
#[must_use]
pub const fn polyomino(&self) -> &Polyomino {
&self.polyomino
}
#[must_use]
pub fn operation(&self) -> Operation {
let (operator, target) = self.op_target();
Operation {
operator,
target: u64::from(target),
}
}
#[must_use]
pub fn contains(&self, cell: Cell) -> bool {
self.polyomino.contains(&cell)
}
#[must_use]
pub fn cells(&self) -> Vec<Cell> {
self.polyomino.cells()
}
pub fn viable_counts(&self, fills: &[Fill]) -> Result<(u64, u64), Error> {
match &self.support {
CageSupport::Given(value) => {
let viable = fills.first().is_some_and(|fill| fill.contains(*value));
Ok(if viable { (1, 1) } else { (0, 0) })
}
CageSupport::Commutative(_, _, memo) => match memo.narrow(fills) {
Ok(narrowed) => Ok((narrowed.multiset_count(), narrowed.tuple_count())),
Err(EmptyFills) => Ok((0, 0)),
Err(e) => Err(e),
},
CageSupport::NonCommutative(_, _, memo) => match memo.narrow(fills) {
Ok(narrowed) => {
let tuples = narrowed.tuples();
let multisets: std::collections::HashSet<Vec<N>> = tuples
.iter()
.map(|tuple| {
let mut sorted = tuple.clone();
sorted.sort_unstable();
sorted
})
.collect();
let count = |len: usize| u64::try_from(len).unwrap_or(u64::MAX);
Ok((count(multisets.len()), count(tuples.len())))
}
Err(EmptyFills) => Ok((0, 0)),
Err(e) => Err(e),
},
}
}
}
impl Display for Cage {
    /// Formats the cage as `Cage(<operation> <cells>)`, with the cells
    /// separated by `", "`, e.g. `Cage(+5 (1, 1), (1, 2))`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut listing = String::new();
        for (position, cell) in self.cells().iter().enumerate() {
            if position > 0 {
                listing.push_str(", ");
            }
            listing.push_str(&cell.to_string());
        }
        write!(f, "Cage({} {})", self.operation(), listing)
    }
}
/// A cage's operator together with its target widened to `u64`.
///
/// Displayed as the operator symbol followed by the target (e.g. `÷2`), except
/// for `Given`, which shows only the target.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Operation {
    /// The cage operator.
    pub operator: CageOperator,
    /// The cage target (for `Given`, the fixed cell value).
    pub target: u64,
}
impl Operation {
#[must_use]
pub const fn new(operator: CageOperator, target: u64) -> Self {
Self { operator, target }
}
}
impl Display for CageOperator {
    /// Writes the operator's symbol; `Given` is written as `=`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Subtract uses U+2212 MINUS SIGN rather than an ASCII hyphen.
        let symbol = match self {
            Self::Add => "+",
            Self::Subtract => "−",
            Self::Multiply => "×",
            Self::Divide => "÷",
            Self::Given => "=",
        };
        f.write_str(symbol)
    }
}
impl Display for Operation {
    /// Writes the operator symbol followed by the target (e.g. `+12`). A
    /// `Given` operation is written as the bare target.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self.operator {
            CageOperator::Given => write!(f, "{}", self.target),
            symbol => write!(f, "{}{}", symbol, self.target),
        }
    }
}
/// Groups the indices (in `polyomino.iter()` order) of cells that share a row
/// or a column.
///
/// Only groups with at least two cells are returned. Row groups come first,
/// ordered by row number, then column groups ordered by column number. Indices
/// within a group are ascending. The result is fully deterministic.
pub fn collinear_groups(polyomino: &Polyomino) -> Vec<Vec<usize>> {
    // A `HashMap` (randomly seeded per instance) made the group order change
    // from call to call, so anything built from the result could differ
    // between runs. An example is the `Mdd` built in `Cage::commutative`.
    // `BTreeMap` iterates in key order.
    let mut by_row: std::collections::BTreeMap<usize, Vec<usize>> =
        std::collections::BTreeMap::new();
    let mut by_col: std::collections::BTreeMap<usize, Vec<usize>> =
        std::collections::BTreeMap::new();
    for (i, &Cell(r, c)) in polyomino.iter().enumerate() {
        by_row.entry(r).or_default().push(i);
        by_col.entry(c).or_default().push(i);
    }
    by_row
        .into_values()
        .chain(by_col.into_values())
        .filter(|group| group.len() >= 2)
        .collect()
}
/// Narrows `memo` against the cells' current fills and returns the surviving
/// candidate values for each of the first `n` cells.
///
/// If every entry of `old_fills` equals `Fill::all(grid_n)`, narrowing is
/// skipped and the memo's base fills are returned directly.
///
/// # Errors
///
/// When nothing survives narrowing ([`EmptyFills`]), the cells get empty
/// fills rather than an error. A narrowed column that reports [`EmptyFills`]
/// is treated the same way. Every other error from the memo is propagated, on
/// both the shortcut path and the narrowing path. Previously the narrowing
/// path silently turned such errors into empty domains, so a bug looked like
/// an unsatisfiable cage.
fn narrow_fills<M: Memo>(
    memo: &M,
    old_fills: &[Fill],
    n: usize,
    grid_n: usize,
) -> Result<Vec<Fill>, Error> {
    let full = Fill::all(grid_n);
    if old_fills.iter().all(|&f| f == full) {
        return (0..n).map(|i| memo.get(i)).collect();
    }
    let narrowed = match memo.narrow(old_fills) {
        Ok(narrowed) => narrowed,
        Err(EmptyFills) => return Ok(vec![Fill::default(); n]),
        Err(e) => return Err(e),
    };
    (0..n)
        .map(|i| match narrowed.get(i) {
            Err(EmptyFills) => Ok(Fill::default()),
            other => other,
        })
        .collect()
}
impl Constraint<Grid, Cell, Fill, Error> for Cage {
    /// Restricts each cage cell in `state` to the values that still appear in
    /// some assignment satisfying the cage. Returns the resulting grid together
    /// with the cell list produced by `Grid::apply_fills`, which the tests
    /// treat as the set of changed cells.
    ///
    /// # Errors
    ///
    /// Propagates errors from reading cell fills out of `state` and from
    /// narrowing the cage's support.
    fn propagate(&self, state: &Grid) -> Result<(Grid, Vec<Cell>), Error> {
        let members: Vec<Cell> = self.polyomino.iter().copied().collect();
        let arity = members.len();
        let mut before: Vec<Fill> = Vec::with_capacity(arity);
        for &member in &members {
            before.push(state.get(member)?);
        }
        let after = match &self.support {
            CageSupport::Commutative(_, _, mdd) => {
                narrow_fills(mdd, &before, arity, state.size())?
            }
            CageSupport::NonCommutative(_, _, table) => {
                narrow_fills(table, &before, arity, state.size())?
            }
            CageSupport::Given(value) => {
                // A given cell keeps only its value, or nothing once that
                // value has been ruled out.
                let singleton = Fill::from(&[*value]);
                let kept = if before[0].contains(*value) {
                    singleton
                } else {
                    Fill::default()
                };
                vec![kept]
            }
        };
        Ok(state.apply_fills(&members, &before, after))
    }

    /// A cell is in scope exactly when it belongs to the cage.
    fn in_scope(&self, variable: Cell) -> bool {
        self.contains(variable)
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for cage construction, accessors, propagation, viability
    // counting, formatting, and the all-full shortcut in `narrow_fills`.
    use super::*;
    use crate::grid::Grid;
    use crate::operator::CommutativeOperator::{Add, Multiply};
    use crate::operator::NonCommutativeOperator::{Divide, Subtract};

    // Two-cell polyomino at (r0, c0) and (r1, c1).
    fn domino(r0: usize, c0: usize, r1: usize, c1: usize) -> Polyomino {
        Polyomino::from([Cell(r0, c0), Cell(r1, c1)]).unwrap()
    }

    // Three-cell polyomino at the given (row, column) coordinates.
    fn triomino(r0: usize, c0: usize, r1: usize, c1: usize, r2: usize, c2: usize) -> Polyomino {
        Polyomino::from([Cell(r0, c0), Cell(r1, c1), Cell(r2, c2)]).unwrap()
    }

    // --- Cage::commutative ---

    #[test]
    fn commutative_add_succeeds() {
        assert!(Cage::commutative(4, domino(1, 1, 1, 2), Add, 5).is_ok());
    }

    #[test]
    fn commutative_multiply_succeeds() {
        assert!(Cage::commutative(4, domino(1, 1, 1, 2), Multiply, 6).is_ok());
    }

    #[test]
    fn commutative_triple_succeeds() {
        assert!(Cage::commutative(4, triomino(1, 1, 1, 2, 1, 3), Add, 6).is_ok());
    }

    // Two cells of a 4-grid can sum to at most 8, so 9 is unreachable.
    #[test]
    fn commutative_infeasible_target_returns_empty_fills() {
        assert!(matches!(
            Cage::commutative(4, domino(1, 1, 1, 2), Add, 9),
            Err(EmptyFills)
        ));
    }

    #[test]
    fn commutative_stores_polyomino() {
        let poly = domino(1, 1, 1, 2);
        let cage = Cage::commutative(4, poly.clone(), Add, 5).unwrap();
        assert_eq!(cage.polyomino, poly);
    }

    // --- Cage::non_commutative ---

    #[test]
    fn non_commutative_subtract_succeeds() {
        assert!(Cage::non_commutative(4, domino(1, 1, 1, 2), Subtract, 1).is_ok());
    }

    #[test]
    fn non_commutative_divide_succeeds() {
        assert!(Cage::non_commutative(4, domino(1, 1, 1, 2), Divide, 2).is_ok());
    }

    // On a 4-grid the largest difference is 3, so 4 is unreachable.
    #[test]
    fn non_commutative_infeasible_target_returns_empty_fills() {
        assert!(matches!(
            Cage::non_commutative(4, domino(1, 1, 1, 2), Subtract, 4),
            Err(EmptyFills)
        ));
    }

    #[test]
    fn non_commutative_stores_polyomino() {
        let poly = domino(2, 1, 2, 2);
        let cage = Cage::non_commutative(4, poly.clone(), Subtract, 1).unwrap();
        assert_eq!(cage.polyomino, poly);
    }

    // --- Constraint::propagate ---

    // Fresh size-`n` grid; the propagation tests treat its cells as unrestricted.
    fn full_grid(n: usize) -> Grid {
        Grid::new(n).unwrap()
    }

    #[test]
    fn cage_propagate_given_pins_cell() {
        let cage = Cage::given(Cell(1, 1), 3).unwrap();
        let (new_g, changed) = cage.propagate(&full_grid(4)).unwrap();
        assert_eq!(new_g.get(Cell(1, 1)).unwrap(), Fill::from(&[3]));
        assert_eq!(changed, vec![Cell(1, 1)]);
    }

    #[test]
    fn cage_propagate_add_prunes_impossible_values() {
        let cage = Cage::commutative(4, domino(1, 1, 1, 2), Add, 3).unwrap();
        let (new_g, _) = cage.propagate(&full_grid(4)).unwrap();
        assert_eq!(new_g.get(Cell(1, 1)).unwrap(), Fill::from(&[1, 2]));
        assert_eq!(new_g.get(Cell(1, 2)).unwrap(), Fill::from(&[1, 2]));
    }

    // Fixing one cell to 4 in an Add-5 domino leaves only 1 for its partner.
    #[test]
    fn cage_propagate_cross_cell_add_prunes_partner() {
        let cage = Cage::commutative(4, domino(1, 1, 1, 2), Add, 5).unwrap();
        let g = full_grid(4).set(Cell(1, 1), Fill::from(&[4]));
        let (new_g, changed) = cage.propagate(&g).unwrap();
        assert_eq!(new_g.get(Cell(1, 2)).unwrap(), Fill::from(&[1]));
        assert!(changed.contains(&Cell(1, 2)));
    }

    #[test]
    fn cage_propagate_cross_cell_subtract_prunes_partner() {
        let cage = Cage::non_commutative(4, domino(1, 1, 1, 2), Subtract, 3).unwrap();
        let g = full_grid(4).set(Cell(1, 1), Fill::from(&[4]));
        let (new_g, _) = cage.propagate(&g).unwrap();
        assert_eq!(new_g.get(Cell(1, 2)).unwrap(), Fill::from(&[1]));
    }

    // With both cells fixed to 4, no tuple reaches 3: both fills become empty.
    #[test]
    fn cage_propagate_no_valid_tuple_empties_values() {
        let g = full_grid(4)
            .set(Cell(1, 1), Fill::from(&[4]))
            .set(Cell(1, 2), Fill::from(&[4]));
        let cage = Cage::commutative(4, domino(1, 1, 1, 2), Add, 3).unwrap();
        let (new_g, changed) = cage.propagate(&g).unwrap();
        assert!(new_g.get(Cell(1, 1)).unwrap().is_empty());
        assert!(new_g.get(Cell(1, 2)).unwrap().is_empty());
        assert_eq!(changed.len(), 2);
    }

    // --- narrow_fills ---

    // Memo wrapper that delegates `get` but panics in `narrow`. It proves that
    // `narrow_fills` answers from `get` alone when every input fill is full.
    struct NoNarrow<M: Memo>(M);
    impl<M: Memo> Memo for NoNarrow<M> {
        fn get(&self, index: usize) -> Result<Fill, Error> {
            self.0.get(index)
        }
        fn narrow(&self, _support: &[Fill]) -> Result<Self, Error> {
            panic!("narrow must not be called when every input fill is full")
        }
    }

    #[test]
    fn narrow_fills_all_full_skips_narrow_and_returns_base_fills() {
        let poly = domino(1, 1, 1, 2);
        let mdd = Mdd::new(4, 2, Add, 3, &collinear_groups(&poly)).unwrap();
        let full = vec![Fill::all(4); 2];
        let fills = narrow_fills(&NoNarrow(mdd), &full, 2, 4).unwrap();
        assert_eq!(fills, vec![Fill::from(&[1, 2]); 2]);
    }

    // The shortcut must agree with actually narrowing against full fills.
    #[test]
    fn narrow_fills_all_full_matches_narrow_path() {
        let poly = triomino(1, 1, 1, 2, 2, 1);
        let mdd = Mdd::new(4, 3, Add, 7, &collinear_groups(&poly)).unwrap();
        let full = vec![Fill::all(4); 3];
        let shortcut = narrow_fills(&mdd, &full, 3, 4).unwrap();
        let narrowed = mdd.narrow(&full).unwrap();
        let via_narrow: Vec<Fill> = (0..3).map(|i| narrowed.get(i).unwrap()).collect();
        assert_eq!(shortcut, via_narrow);
    }

    #[test]
    fn narrow_fills_partial_input_still_narrows() {
        let poly = domino(1, 1, 1, 2);
        let mdd = Mdd::new(4, 2, Add, 5, &collinear_groups(&poly)).unwrap();
        let fills = narrow_fills(&mdd, &[Fill::from(&[4]), Fill::all(4)], 2, 4).unwrap();
        assert_eq!(fills, vec![Fill::from(&[4]), Fill::from(&[1])]);
    }

    #[test]
    fn cage_propagate_subtract_full_grid_returns_base_fills() {
        let cage = Cage::non_commutative(4, domino(1, 1, 1, 2), Subtract, 3).unwrap();
        let (new_g, _) = cage.propagate(&full_grid(4)).unwrap();
        assert_eq!(new_g.get(Cell(1, 1)).unwrap(), Fill::from(&[1, 4]));
        assert_eq!(new_g.get(Cell(1, 2)).unwrap(), Fill::from(&[1, 4]));
    }

    // --- Cage::given ---

    #[test]
    fn given_succeeds() {
        assert!(Cage::given(Cell(1, 1), 3).is_ok());
    }

    #[test]
    fn given_stores_singleton_polyomino() {
        let cage = Cage::given(Cell(2, 3), 5).unwrap();
        assert!(cage.polyomino.contains(&Cell(2, 3)));
        assert_eq!(cage.polyomino.len(), 1);
    }

    #[test]
    fn given_stores_target_as_value() {
        let cage = Cage::given(Cell(1, 1), 7).unwrap();
        assert_eq!(cage.support, CageSupport::Given(7));
    }

    // --- Cage::get ---

    #[test]
    fn get_missing_cell_returns_error() {
        let cage = Cage::commutative(4, domino(1, 1, 1, 2), Add, 5).unwrap();
        assert!(matches!(cage.get(Cell(9, 9)), Err(Error::MissingCell(_))));
    }

    #[test]
    fn get_given_returns_singleton_fill() {
        let cage = Cage::given(Cell(1, 1), 3).unwrap();
        assert_eq!(cage.get(Cell(1, 1)).unwrap(), Fill::from(&[3]));
    }

    #[test]
    fn get_commutative_returns_base_fill() {
        let cage = Cage::commutative(4, domino(1, 1, 1, 2), Add, 3).unwrap();
        assert_eq!(cage.get(Cell(1, 1)).unwrap(), Fill::from(&[1, 2]));
    }

    #[test]
    fn get_non_commutative_returns_base_fill() {
        let cage = Cage::non_commutative(4, domino(1, 1, 1, 2), Subtract, 3).unwrap();
        assert_eq!(cage.get(Cell(1, 2)).unwrap(), Fill::from(&[1, 4]));
    }

    // --- Cage::new ---

    #[test]
    fn new_subtract_wrong_arity_is_infeasible() {
        assert!(matches!(
            Cage::new(4, triomino(1, 1, 1, 2, 1, 3), CageOperator::Subtract, 1),
            Err(Error::InfeasibleCage(_, 1))
        ));
    }

    #[test]
    fn new_divide_wrong_arity_is_infeasible() {
        assert!(matches!(
            Cage::new(4, triomino(1, 1, 1, 2, 1, 3), CageOperator::Divide, 2),
            Err(Error::InfeasibleCage(_, 2))
        ));
    }

    #[test]
    fn new_given_wrong_arity_is_infeasible() {
        assert!(matches!(
            Cage::new(4, domino(1, 1, 1, 2), CageOperator::Given, 1),
            Err(Error::InfeasibleCage(_, 1))
        ));
    }

    #[test]
    fn new_given_target_out_of_value_range_is_infeasible() {
        let poly = Polyomino::from([Cell(1, 1)]).unwrap();
        assert!(matches!(
            Cage::new(4, poly, CageOperator::Given, 1000),
            Err(Error::InfeasibleCage(_, 1000))
        ));
    }

    // `new` maps the constructor's EmptyFills to InfeasibleCage.
    #[test]
    fn new_add_unreachable_target_maps_to_infeasible_cage() {
        assert!(matches!(
            Cage::new(4, domino(1, 1, 1, 2), CageOperator::Add, 9),
            Err(Error::InfeasibleCage(_, 9))
        ));
    }

    #[test]
    fn new_builds_every_operator() {
        let p = domino(1, 1, 1, 2);
        assert!(Cage::new(4, p.clone(), CageOperator::Add, 5).is_ok());
        assert!(Cage::new(4, p.clone(), CageOperator::Subtract, 1).is_ok());
        assert!(Cage::new(4, p.clone(), CageOperator::Multiply, 6).is_ok());
        assert!(Cage::new(4, p, CageOperator::Divide, 2).is_ok());
        let single = Polyomino::from([Cell(1, 1)]).unwrap();
        assert!(Cage::new(4, single, CageOperator::Given, 3).is_ok());
    }

    // --- accessors ---

    #[test]
    fn op_target_round_trips_every_operator() {
        let p = domino(1, 1, 1, 2);
        let cases = [
            (CageOperator::Add, 5),
            (CageOperator::Subtract, 1),
            (CageOperator::Multiply, 6),
            (CageOperator::Divide, 2),
        ];
        for (op, target) in cases {
            let cage = Cage::new(4, p.clone(), op, target).unwrap();
            assert_eq!(cage.op_target(), (op, target));
        }
        let given = Cage::given(Cell(1, 1), 3).unwrap();
        assert_eq!(given.op_target(), (CageOperator::Given, 3));
    }

    #[test]
    fn polyomino_accessor_returns_cells() {
        let p = domino(1, 1, 1, 2);
        let cage = Cage::commutative(4, p.clone(), Add, 5).unwrap();
        assert_eq!(cage.polyomino(), &p);
    }

    #[test]
    fn operation_accessor_combines_operator_and_target() {
        let cage = Cage::commutative(4, domino(1, 1, 1, 2), Multiply, 6).unwrap();
        assert_eq!(cage.operation(), Operation::new(CageOperator::Multiply, 6));
    }

    #[test]
    fn contains_member_and_non_member() {
        let cage = Cage::commutative(4, domino(1, 1, 1, 2), Add, 5).unwrap();
        assert!(cage.contains(Cell(1, 1)));
        assert!(!cage.contains(Cell(2, 2)));
    }

    #[test]
    fn cells_returns_sorted_cells() {
        let cage = Cage::commutative(4, domino(1, 1, 1, 2), Add, 5).unwrap();
        assert_eq!(cage.cells(), vec![Cell(1, 1), Cell(1, 2)]);
    }

    // --- viable_counts ---

    #[test]
    fn viable_counts_commutative_no_survivors_is_zero() {
        let cage = Cage::commutative(4, domino(1, 1, 1, 2), Add, 5).unwrap();
        let fills = vec![Fill::from(&[3]); 2];
        assert_eq!(cage.viable_counts(&fills).unwrap(), (0, 0));
    }

    #[test]
    fn viable_counts_non_commutative_no_survivors_is_zero() {
        let cage = Cage::non_commutative(4, domino(1, 1, 1, 2), Subtract, 3).unwrap();
        let fills = vec![Fill::from(&[2]); 2];
        assert_eq!(cage.viable_counts(&fills).unwrap(), (0, 0));
    }

    #[test]
    fn viable_counts_given_present_and_absent() {
        let cage = Cage::given(Cell(1, 1), 3).unwrap();
        assert_eq!(cage.viable_counts(&[Fill::from(&[1, 3])]).unwrap(), (1, 1));
        assert_eq!(cage.viable_counts(&[Fill::from(&[1, 2])]).unwrap(), (0, 0));
    }

    // --- Display ---

    #[test]
    fn display_cage_shows_operation_and_cells() {
        let cage = Cage::commutative(4, domino(1, 1, 1, 2), Add, 5).unwrap();
        assert_eq!(cage.to_string(), "Cage(+5 (1, 1), (1, 2))");
    }

    #[test]
    fn display_cage_operator_symbols() {
        assert_eq!(CageOperator::Add.to_string(), "+");
        assert_eq!(CageOperator::Subtract.to_string(), "−");
        assert_eq!(CageOperator::Multiply.to_string(), "×");
        assert_eq!(CageOperator::Divide.to_string(), "÷");
        assert_eq!(CageOperator::Given.to_string(), "=");
    }

    #[test]
    fn display_operation_given_omits_operator() {
        assert_eq!(Operation::new(CageOperator::Given, 3).to_string(), "3");
        assert_eq!(Operation::new(CageOperator::Divide, 2).to_string(), "÷2");
    }

    // Sanity check that the NoNarrow guard really fires.
    #[test]
    #[should_panic(expected = "narrow must not be called")]
    fn no_narrow_panics_when_narrowed() {
        let poly = domino(1, 1, 1, 2);
        let mdd = Mdd::new(4, 2, Add, 3, &collinear_groups(&poly)).unwrap();
        let _ = NoNarrow(mdd).narrow(&[Fill::from(&[1]), Fill::all(4)]);
    }
}