/// Bitmask bookkeeping for a 9×9 Sudoku grid.
///
/// Every row, column and 3×3 box has one `u16` mask. Bit `n - 1` of a mask is
/// set when digit `n` (1..=9) has been recorded in that unit, so placement
/// checks and candidate sets reduce to a few bitwise operations. For digits
/// 1..=9 only the low nine bits of each mask are used.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct Masks {
    /// Digits recorded in each row, indexed by row (`0..9`).
    row_masks: [u16; 9],
    /// Digits recorded in each column, indexed by column (`0..9`).
    col_masks: [u16; 9],
    /// Digits recorded in each 3×3 box, indexed by `Masks::get_box_idx`.
    box_masks: [u16; 9],
}

impl Masks {
    /// Mask with one bit set for each digit 1..=9.
    const ALL_DIGITS: u16 = 0x1FF;

    /// Creates masks with no digits recorded in any row, column or box.
    pub(super) fn new() -> Self {
        Self::default()
    }

    /// Returns the index of the 3×3 box containing cell `(r, c)`.
    ///
    /// Boxes are numbered row-major: 0..=2 across the top band, 3..=5 across
    /// the middle band and 6..=8 across the bottom band. For `r, c < 9` the
    /// result is always `< 9`.
    #[must_use]
    pub(super) fn get_box_idx(r: usize, c: usize) -> usize {
        (r / 3) * 3 + (c / 3)
    }

    /// Bit representing digit `num` (bit `num - 1`).
    ///
    /// An out-of-range digit would otherwise underflow (`num == 0`) or land
    /// outside the nine-bit field (`num > 9`) and silently corrupt the masks,
    /// so debug builds reject it with a descriptive message.
    fn digit_bit(num: u8) -> u16 {
        debug_assert!(
            (1..=9).contains(&num),
            "digit must be in 1..=9, got {}",
            num
        );
        1 << (num - 1)
    }

    /// Union of the digits already recorded in the row, column and box of `(r, c)`.
    fn used_mask(&self, r: usize, c: usize) -> u16 {
        self.row_masks[r] | self.col_masks[c] | self.box_masks[Self::get_box_idx(r, c)]
    }

    /// Records digit `num` in the row, column and box of cell `(r, c)`.
    ///
    /// Recording a digit that is already present is a no-op.
    ///
    /// # Panics
    ///
    /// Panics if `r` or `c` is not in `0..9`. Debug builds also panic if `num`
    /// is not in `1..=9`.
    pub(super) fn add_number(&mut self, r: usize, c: usize, num: u8) {
        let bit = Self::digit_bit(num);
        let box_idx = Self::get_box_idx(r, c);
        self.row_masks[r] |= bit;
        self.col_masks[c] |= bit;
        self.box_masks[box_idx] |= bit;
    }

    /// Clears digit `num` from the row, column and box of cell `(r, c)`.
    ///
    /// The masks record only which digits each unit contains, not which cell
    /// holds them, so the bit is cleared unconditionally. Callers should only
    /// remove a digit they previously added at this cell. Clearing a digit
    /// that is not present is a no-op.
    ///
    /// # Panics
    ///
    /// Panics if `r` or `c` is not in `0..9`. Debug builds also panic if `num`
    /// is not in `1..=9`.
    pub(super) fn remove_number(&mut self, r: usize, c: usize, num: u8) {
        let bit = Self::digit_bit(num);
        let box_idx = Self::get_box_idx(r, c);
        self.row_masks[r] &= !bit;
        self.col_masks[c] &= !bit;
        self.box_masks[box_idx] &= !bit;
    }

    /// Returns `true` if digit `num` does not yet appear in the row, column or
    /// box of cell `(r, c)`.
    ///
    /// A digit already recorded at `(r, c)` itself counts as a conflict.
    ///
    /// # Panics
    ///
    /// Panics if `r` or `c` is not in `0..9`. Debug builds also panic if `num`
    /// is not in `1..=9`.
    #[must_use]
    pub fn is_safe(&self, r: usize, c: usize, num: u8) -> bool {
        (self.used_mask(r, c) & Self::digit_bit(num)) == 0
    }

    /// Returns the digits that may still be placed at `(r, c)`.
    ///
    /// Bit `n - 1` of the result is set when digit `n` is absent from the
    /// cell's row, column and box. A result of `0` means no digit fits.
    ///
    /// # Panics
    ///
    /// Panics if `r` or `c` is not in `0..9`.
    #[must_use]
    pub(super) fn compute_candidates_mask_for_cell(&self, r: usize, c: usize) -> u16 {
        !self.used_mask(r, c) & Self::ALL_DIGITS
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Mask bit for a single digit `n` (expects 1..=9).
    fn bit(n: u8) -> u16 {
        1u16 << (n - 1)
    }

    /// Union of the mask bits for every digit in `nums`.
    fn bits(nums: &[u8]) -> u16 {
        nums.iter().map(|&n| bit(n)).fold(0, |acc, b| acc | b)
    }

    #[test]
    fn test_new_initializes_empty_masks() {
        let masks = Masks::new();
        assert_eq!(masks.row_masks, [0; 9]);
        assert_eq!(masks.col_masks, [0; 9]);
        assert_eq!(masks.box_masks, [0; 9]);
    }

    #[test]
    fn test_get_box_idx_top_left_box() {
        assert_eq!(Masks::get_box_idx(0, 0), 0);
        assert_eq!(Masks::get_box_idx(1, 1), 0);
        assert_eq!(Masks::get_box_idx(2, 2), 0);
    }

    #[test]
    fn test_get_box_idx_middle_box() {
        assert_eq!(Masks::get_box_idx(3, 3), 4);
        assert_eq!(Masks::get_box_idx(4, 4), 4);
        assert_eq!(Masks::get_box_idx(5, 5), 4);
    }

    #[test]
    fn test_get_box_idx_bottom_right_box() {
        assert_eq!(Masks::get_box_idx(6, 6), 8);
        assert_eq!(Masks::get_box_idx(7, 7), 8);
        assert_eq!(Masks::get_box_idx(8, 8), 8);
    }

    #[test]
    fn test_get_box_idx_various_boxes() {
        assert_eq!(Masks::get_box_idx(0, 3), 1);
        assert_eq!(Masks::get_box_idx(3, 0), 3);
        assert_eq!(Masks::get_box_idx(0, 8), 2);
        assert_eq!(Masks::get_box_idx(8, 0), 6);
    }

    #[test]
    fn test_get_box_idx_assigns_nine_cells_to_each_box() {
        // Every box index must be hit by exactly nine of the 81 cells.
        let mut counts = [0usize; 9];
        for r in 0..9 {
            for c in 0..9 {
                counts[Masks::get_box_idx(r, c)] += 1;
            }
        }
        assert_eq!(counts, [9; 9]);
    }

    #[test]
    fn test_add_number_single_cell() {
        let mut masks = Masks::new();
        masks.add_number(0, 0, 1);
        assert_eq!(masks.row_masks[0], bit(1));
        assert_eq!(masks.col_masks[0], bit(1));
        assert_eq!(masks.box_masks[0], bit(1));
    }

    #[test]
    fn test_add_number_multiple_in_same_row() {
        let mut masks = Masks::new();
        masks.add_number(0, 0, 1);
        masks.add_number(0, 1, 5);
        assert_eq!(masks.row_masks[0], bits(&[1, 5]));
        assert_eq!(masks.col_masks[0], bit(1));
        assert_eq!(masks.col_masks[1], bit(5));
        assert_eq!(masks.box_masks[0], bits(&[1, 5]));
    }

    #[test]
    fn test_add_number_to_different_box() {
        let mut masks = Masks::new();
        masks.add_number(8, 8, 9);
        assert_eq!(masks.row_masks[8], bit(9));
        assert_eq!(masks.col_masks[8], bit(9));
        assert_eq!(masks.box_masks[8], bit(9));
    }

    #[test]
    fn test_add_number_already_present_no_change() {
        let mut masks = Masks::new();
        masks.add_number(0, 0, 1);
        let initial_row_0 = masks.row_masks[0];
        masks.add_number(0, 0, 1);
        assert_eq!(masks.row_masks[0], initial_row_0);
    }

    #[test]
    fn test_remove_number_single_value_from_cell() {
        let mut masks = Masks::new();
        masks.add_number(0, 0, 1);
        masks.remove_number(0, 0, 1);
        assert_eq!(masks.row_masks[0], 0);
        assert_eq!(masks.col_masks[0], 0);
        assert_eq!(masks.box_masks[0], 0);
    }

    #[test]
    fn test_remove_number_from_shared_row() {
        let mut masks = Masks::new();
        masks.add_number(0, 0, 1);
        masks.add_number(0, 1, 5);
        masks.remove_number(0, 0, 1);
        assert_eq!(masks.row_masks[0], bit(5));
        assert_eq!(masks.col_masks[0], 0);
        assert_eq!(masks.col_masks[1], bit(5));
        assert_eq!(masks.box_masks[0], bit(5));
    }

    #[test]
    fn test_remove_number_not_present_no_change() {
        let mut masks = Masks::new();
        masks.add_number(0, 0, 1);
        let initial_row_0 = masks.row_masks[0];
        masks.remove_number(0, 0, 3);
        assert_eq!(masks.row_masks[0], initial_row_0);
    }

    #[test]
    fn test_remove_number_restores_full_candidates() {
        let mut masks = Masks::new();
        masks.add_number(4, 4, 7);
        masks.remove_number(4, 4, 7);
        assert_eq!(masks, Masks::new());
        assert_eq!(masks.compute_candidates_mask_for_cell(4, 4), 0x1FF);
    }

    #[test]
    fn test_is_safe_on_empty_board_always_true() {
        let masks = Masks::new();
        assert!(masks.is_safe(0, 0, 1));
        assert!(masks.is_safe(8, 8, 9));
    }

    #[test]
    fn test_is_safe_conflict_in_row() {
        let mut masks = Masks::new();
        masks.add_number(0, 0, 1);
        assert!(!masks.is_safe(0, 1, 1));
        assert!(masks.is_safe(0, 1, 2));
    }

    #[test]
    fn test_is_safe_conflict_in_column() {
        let mut masks = Masks::new();
        masks.add_number(0, 0, 1);
        assert!(!masks.is_safe(1, 0, 1));
        assert!(masks.is_safe(1, 0, 2));
    }

    #[test]
    fn test_is_safe_conflict_in_box() {
        let mut masks = Masks::new();
        masks.add_number(1, 1, 1);
        assert!(!masks.is_safe(0, 0, 1));
        assert!(masks.is_safe(0, 0, 2));
    }

    #[test]
    fn test_is_safe_conflict_with_current_cell_value() {
        let mut masks = Masks::new();
        masks.add_number(0, 0, 1);
        assert!(!masks.is_safe(0, 0, 1));
        assert!(masks.is_safe(0, 0, 2));
    }

    #[test]
    fn test_compute_candidates_empty_cell_all_available() {
        let masks = Masks::new();
        let candidates = masks.compute_candidates_mask_for_cell(0, 0);
        assert_eq!(candidates, 0x1FF);
    }

    #[test]
    fn test_compute_candidates_row_has_1_to_8_only_9_available() {
        let mut masks = Masks::new();
        masks.row_masks[0] = bits(&[1, 2, 3, 4, 5, 6, 7, 8]);
        let candidates = masks.compute_candidates_mask_for_cell(0, 0);
        assert_eq!(candidates, bit(9));
    }

    #[test]
    fn test_compute_candidates_col_has_1_to_8_only_9_available() {
        let mut masks = Masks::new();
        masks.col_masks[0] = bits(&[1, 2, 3, 4, 5, 6, 7, 8]);
        let candidates = masks.compute_candidates_mask_for_cell(0, 0);
        assert_eq!(candidates, bit(9));
    }

    #[test]
    fn test_compute_candidates_box_has_1_to_8_only_9_available() {
        let mut masks = Masks::new();
        masks.box_masks[Masks::get_box_idx(1, 1)] = bits(&[1, 2, 3, 4, 5, 6, 7, 8]);
        let candidates = masks.compute_candidates_mask_for_cell(1, 1);
        assert_eq!(candidates, bit(9));
    }

    #[test]
    fn test_compute_candidates_mixed_restrictions() {
        let mut masks = Masks::new();
        masks.row_masks[0] = bits(&[1, 2]);
        masks.col_masks[0] = bits(&[3, 4]);
        masks.box_masks[0] = bits(&[5, 6]);
        let candidates = masks.compute_candidates_mask_for_cell(0, 0);
        assert_eq!(candidates, bits(&[7, 8, 9]));
    }

    #[test]
    fn test_compute_candidates_after_add_number_in_each_unit() {
        let mut masks = Masks::new();
        masks.add_number(0, 5, 1); // same row as (0, 0), different box
        masks.add_number(7, 0, 2); // same column as (0, 0), different box
        masks.add_number(2, 2, 9); // same box as (0, 0), different row/column
        let candidates = masks.compute_candidates_mask_for_cell(0, 0);
        assert_eq!(candidates, bits(&[3, 4, 5, 6, 7, 8]));
    }

    #[test]
    fn test_compute_candidates_no_candidates_left() {
        let mut masks = Masks::new();
        masks.row_masks[0] = 0x1FF;
        let candidates = masks.compute_candidates_mask_for_cell(0, 0);
        assert_eq!(candidates, 0);
    }
}