1#![warn(missing_docs)]
48#![allow(clippy::needless_range_loop)]
49
50use std::fmt;
51use std::str::FromStr;
52
/// The nine low bits set: one bit per digit, bit `d - 1` standing for digit `d`.
///
/// Used to turn "digits already present" into "digits still available".
pub(crate) const FULL_MASK: u16 = (1 << 9) - 1;
57
/// Row (0..9) of each cell, for cells numbered `row * 9 + col`.
pub(crate) const ROW: [u8; 81] = generate_row();
/// Column (0..9) of each cell, for cells numbered `row * 9 + col`.
pub(crate) const COL: [u8; 81] = generate_col();
/// 3x3 box (0..9) of each cell; boxes are numbered left to right, top to bottom.
pub(crate) const BOX: [u8; 81] = generate_box();

/// Builds the cell -> row table at compile time.
const fn generate_row() -> [u8; 81] {
    let mut table = [0u8; 81];
    let mut r = 0;
    while r < 9 {
        let mut c = 0;
        while c < 9 {
            table[r * 9 + c] = r as u8;
            c += 1;
        }
        r += 1;
    }
    table
}

/// Builds the cell -> column table at compile time.
const fn generate_col() -> [u8; 81] {
    let mut table = [0u8; 81];
    let mut r = 0;
    while r < 9 {
        let mut c = 0;
        while c < 9 {
            table[r * 9 + c] = c as u8;
            c += 1;
        }
        r += 1;
    }
    table
}

/// Builds the cell -> 3x3 box table at compile time.
const fn generate_box() -> [u8; 81] {
    let mut table = [0u8; 81];
    let mut r = 0;
    while r < 9 {
        let mut c = 0;
        while c < 9 {
            // Band of three rows picks the box row; stack of three columns picks the box column.
            table[r * 9 + c] = ((r / 3) * 3 + c / 3) as u8;
            c += 1;
        }
        r += 1;
    }
    table
}
80
/// Reasons a string can fail to parse as a `Sudoku`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ParseError {
    /// The input was not 81 characters long; carries the length that was measured.
    InvalidLength(usize),
    /// The input contained something other than an ASCII digit `0`-`9`.
    InvalidCharacter {
        /// 0-based position of the offending character.
        position: usize,
        /// The offending character.
        ch: char,
    },
    /// A clue repeats a digit already placed earlier in its row, column, or box.
    DuplicateDigit {
        /// 0-based cell index (`row * 9 + col`) of the repeating clue.
        position: usize,
        /// The repeated digit (`1..=9`).
        digit: u8,
    },
}

impl fmt::Display for ParseError {
    /// Renders a one-line, human-readable description of the failure.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every payload field is `Copy`, so match on the value and bind by copy.
        match *self {
            Self::InvalidLength(got) => write!(f, "puzzle must be 81 characters, got {}", got),
            Self::InvalidCharacter { position, ch } => {
                write!(f, "invalid character {:?} at position {}", ch, position)
            }
            Self::DuplicateDigit { position, digit } => write!(
                f,
                "digit {} at position {} conflicts with its row/column/box",
                digit, position
            ),
        }
    }
}

/// `ParseError` wraps no underlying cause, so the default `source()` (`None`) applies.
impl std::error::Error for ParseError {}
134
/// A 9x9 Sudoku board plus the bookkeeping the solver uses.
///
/// Cells are stored row-major (`row * 9 + col`); `0` means empty and `1..=9`
/// is a placed digit. Alongside the grid, one 9-bit mask per row, column, and
/// 3x3 box records which digits that unit already contains (bit `d - 1` for
/// digit `d`), so a cell's candidates are a few bitwise operations away.
// NOTE(review): `align(64)` presumably starts each board on a cache-line
// boundary, and `repr(C)` pins the field order — confirm intent before
// reordering fields.
#[repr(C, align(64))]
#[derive(Clone, Copy)]
pub struct Sudoku {
    /// The cell values in row-major order; `0` is an empty cell.
    ///
    /// NOTE(review): this field is public, but writing to it directly does not
    /// update the row/column/box masks below, which parsing and solving
    /// otherwise keep in step with it.
    pub cells: [u8; 81],
    /// `rows[r]` has bit `d - 1` set when digit `d` appears in row `r`.
    rows: [u16; 9],
    /// `cols[c]` has bit `d - 1` set when digit `d` appears in column `c`.
    cols: [u16; 9],
    /// `boxes[b]` has bit `d - 1` set when digit `d` appears in box `b`
    /// (boxes numbered as in the `BOX` table).
    boxes: [u16; 9],
    /// Scratch list of empty-cell indices, refilled by `solve` on each call.
    empty: [u8; 81],
}
181
182impl Sudoku {
183 pub fn from_string(s: &str) -> Option<Self> {
203 s.parse().ok()
204 }
205
206 #[inline]
222 pub fn cells(&self) -> &[u8; 81] {
223 &self.cells
224 }
225
226 #[inline]
239 pub fn is_solved(&self) -> bool {
240 self.cells.iter().all(|&v| v != 0)
241 }
242
243 #[inline]
264 pub fn get(&self, r: usize, c: usize) -> u8 {
265 assert!(r < 9 && c < 9, "row and column must be 0..9");
266 self.cells[r * 9 + c]
267 }
268
269 pub fn solve(&mut self) -> bool {
291 let mut n = 0usize;
292 for i in 0..81 {
293 if self.cells[i] == 0 {
294 self.empty[n] = i as u8;
295 n += 1;
296 }
297 }
298 self.solve_recursive(n)
299 }
300
301 #[inline(always)]
302 fn get_mask(&self, idx: usize) -> u16 {
303 debug_assert!(idx < 81);
304 unsafe {
305 let r = *ROW.get_unchecked(idx) as usize;
306 let c = *COL.get_unchecked(idx) as usize;
307 let b = *BOX.get_unchecked(idx) as usize;
308 debug_assert!(r < 9 && c < 9 && b < 9);
309 !(*self.rows.get_unchecked(r)
310 | *self.cols.get_unchecked(c)
311 | *self.boxes.get_unchecked(b))
312 & FULL_MASK
313 }
314 }
315
316 fn solve_recursive(&mut self, num_empty: usize) -> bool {
317 if num_empty == 0 { return true; }
318
319 let mut best_i = 0;
320 let mut min_c = 10u32;
321 let mut best_mask = 0u16;
322
323 for i in 0..num_empty {
324 debug_assert!(i < 81);
325 let idx = unsafe { *self.empty.get_unchecked(i) as usize };
326 debug_assert!(idx < 81);
327 let mask = self.get_mask(idx);
328 let count = mask.count_ones();
329 if count == 0 { return false; }
330 if count < min_c {
331 min_c = count;
332 best_i = i;
333 best_mask = mask;
334 if count == 1 { break; }
335 }
336 }
337
338 let idx = self.empty[best_i] as usize;
339 let last = num_empty - 1;
340 let saved = self.empty[last];
341 self.empty[best_i] = saved;
342
343 let mut m = best_mask;
344 while m != 0 {
345 let bit = m & m.wrapping_neg();
346 m ^= bit;
347 let digit = (bit.trailing_zeros() + 1) as u8;
348 let r = ROW[idx] as usize;
349 let c = COL[idx] as usize;
350 let b = BOX[idx] as usize;
351
352 self.cells[idx] = digit;
353 self.rows[r] |= bit;
354 self.cols[c] |= bit;
355 self.boxes[b] |= bit;
356
357 if self.solve_recursive(last) { return true; }
358
359 self.rows[r] &= !bit;
360 self.cols[c] &= !bit;
361 self.boxes[b] &= !bit;
362 }
363
364 self.cells[idx] = 0;
365 self.empty[best_i] = idx as u8;
366 self.empty[last] = saved;
367 false
368 }
369
370 pub fn to_digit_string(&self) -> String {
384 self.cells.iter().map(|&v| (v + b'0') as char).collect()
385 }
386
387 pub fn print_grid(&self) {
389 for row in 0..9 {
390 if row % 3 == 0 && row != 0 {
391 println!("------+-------+------");
392 }
393 for col in 0..9 {
394 if col % 3 == 0 && col != 0 { print!("| "); }
395 let v = self.cells[row * 9 + col];
396 if v == 0 { print!(". "); } else { print!("{} ", v); }
397 }
398 println!();
399 }
400 }
401}
402
403impl FromStr for Sudoku {
406 type Err = ParseError;
407
408 fn from_str(s: &str) -> Result<Self, Self::Err> {
428 if s.len() != 81 {
429 return Err(ParseError::InvalidLength(s.len()));
430 }
431 let mut board = Sudoku {
432 cells: [0; 81],
433 rows: [0; 9],
434 cols: [0; 9],
435 boxes: [0; 9],
436 empty: [0; 81],
437 };
438 for (i, ch) in s.bytes().enumerate() {
439 let val = match ch {
440 b'0'..=b'9' => ch - b'0',
441 _ => return Err(ParseError::InvalidCharacter {
442 position: i,
443 ch: ch as char,
444 }),
445 };
446 if val != 0 {
447 let bit: u16 = 1 << (val - 1);
448 let r = ROW[i] as usize;
449 let c = COL[i] as usize;
450 let b = BOX[i] as usize;
451 if (board.rows[r] | board.cols[c] | board.boxes[b]) & bit != 0 {
452 return Err(ParseError::DuplicateDigit { position: i, digit: val });
453 }
454 board.cells[i] = val;
455 board.rows[r] |= bit;
456 board.cols[c] |= bit;
457 board.boxes[b] |= bit;
458 }
459 }
460 Ok(board)
461 }
462}
463
464impl fmt::Display for Sudoku {
477 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
478 for &v in &self.cells {
479 write!(f, "{}", v)?;
480 }
481 Ok(())
482 }
483}
484
485impl fmt::Debug for Sudoku {
487 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
488 write!(f, "Sudoku(\"{}\")", self)
489 }
490}
491
492impl PartialEq for Sudoku {
493 fn eq(&self, other: &Self) -> bool {
494 self.cells == other.cells
495 }
496}
497
498impl Eq for Sudoku {}
499
500impl Default for Sudoku {
514 fn default() -> Self {
515 Sudoku {
516 cells: [0; 81],
517 rows: [0; 9],
518 cols: [0; 9],
519 boxes: [0; 9],
520 empty: [0; 81],
521 }
522 }
523}
524
#[cfg(test)]
mod tests {
    use super::*;

    // Well-known benchmark puzzles shared by several tests below.
    const AL_ESCARGOT: &str =
        "100007060900020008080500000000305070020010000800000400004000000000460010030900005";
    const HARDEST_2012: &str =
        "800000000003600000070090200050007000000045700000100030001000068008500010090000400";
    const PLATINUM_BLONDE: &str =
        "000000012000000003002300400001800005060000070004000600000050090000200001000000000";
    const NORVIG_HARD: &str =
        "400000805030000000000700000020000060000080400000010000000603070500200000104000000";
    const CLASSIC_EASY: &str =
        "003020600900305001001806400008102900700000008006708200002609500800203009005010300";

    impl Sudoku {
        /// Returns `true` if the cached row/column/box masks are exactly the
        /// ones implied by the current `cells`.
        fn bitboard_matches_cells(&self) -> bool {
            let mut er = [0u16; 9];
            let mut ec = [0u16; 9];
            let mut eb = [0u16; 9];
            for i in 0..81 {
                let v = self.cells[i];
                if v != 0 {
                    let bit = 1u16 << (v - 1);
                    er[ROW[i] as usize] |= bit;
                    ec[COL[i] as usize] |= bit;
                    eb[BOX[i] as usize] |= bit;
                }
            }
            self.rows == er && self.cols == ec && self.boxes == eb
        }
    }

    /// Returns `true` if `cells` is a complete, valid grid: every row, column,
    /// and box contains each digit 1..=9 exactly once.
    fn is_valid_solution(cells: &[u8; 81]) -> bool {
        // The row pass also rejects 0 and values above 9, which keeps the
        // shifts in the column and box passes in range.
        for r in 0..9 {
            let mut m = 0u16;
            for c in 0..9 {
                let v = cells[r * 9 + c];
                if v == 0 || v > 9 {
                    return false;
                }
                let bit = 1u16 << (v - 1);
                if m & bit != 0 {
                    return false;
                }
                m |= bit;
            }
            if m != FULL_MASK {
                return false;
            }
        }
        for c in 0..9 {
            let mut m = 0u16;
            for r in 0..9 {
                let v = cells[r * 9 + c];
                let bit = 1u16 << (v - 1);
                if m & bit != 0 {
                    return false;
                }
                m |= bit;
            }
            if m != FULL_MASK {
                return false;
            }
        }
        for br in 0..3 {
            for bc in 0..3 {
                let mut m = 0u16;
                for dr in 0..3 {
                    for dc in 0..3 {
                        let v = cells[(br * 3 + dr) * 9 + (bc * 3 + dc)];
                        let bit = 1u16 << (v - 1);
                        if m & bit != 0 {
                            return false;
                        }
                        m |= bit;
                    }
                }
                if m != FULL_MASK {
                    return false;
                }
            }
        }
        true
    }

    /// Asserts that every non-`'0'` clue in `puzzle` is still present in `s`.
    fn assert_clues_preserved(puzzle: &str, s: &Sudoku) {
        for (i, ch) in puzzle.bytes().enumerate() {
            if ch != b'0' {
                assert_eq!(
                    s.cells[i],
                    ch - b'0',
                    "Clue at pos {} (row {}, col {}) overwritten",
                    i,
                    i / 9,
                    i % 9
                );
            }
        }
    }

    /// Builds an 81-character puzzle string of `'0'`s with the given
    /// `(cell_index, digit)` placements. Tests can then state their geometry
    /// explicitly instead of relying on hand-counted literals.
    fn board_with(placements: &[(usize, u8)]) -> String {
        let mut bytes = [b'0'; 81];
        for &(i, d) in placements {
            bytes[i] = b'0' + d;
        }
        bytes.iter().map(|&b| b as char).collect()
    }

    /// Parses and solves `puzzle`, asserting that the result is a valid,
    /// complete grid that keeps every clue and has consistent bitboards.
    fn assert_solves(puzzle: &str) -> Sudoku {
        let mut s: Sudoku = puzzle.parse().expect("test puzzle must parse");
        assert!(s.solve(), "no solution found");
        assert!(is_valid_solution(&s.cells), "invalid solution");
        assert_clues_preserved(puzzle, &s);
        assert!(s.is_solved());
        assert!(s.bitboard_matches_cells(), "bitboard mismatch");
        s
    }

    // ---- L1: lookup tables -------------------------------------------------

    #[test]
    fn l1_full_mask_is_9_bit_all_ones() {
        assert_eq!(FULL_MASK, 0b1_1111_1111);
        assert_eq!(FULL_MASK, 0x1FF);
        assert_eq!(FULL_MASK, 511u16);
        assert_eq!(FULL_MASK.count_ones(), 9);
        assert_eq!(FULL_MASK & !0x1FF, 0);
    }

    #[test]
    fn l1_row_table_is_correct() {
        for i in 0..81usize {
            assert_eq!(ROW[i], (i / 9) as u8);
            assert!(ROW[i] < 9);
        }
        let mut c = [0u8; 9];
        for &r in ROW.iter() {
            c[r as usize] += 1;
        }
        assert_eq!(c, [9u8; 9]);
    }

    #[test]
    fn l1_col_table_is_correct() {
        for i in 0..81usize {
            assert_eq!(COL[i], (i % 9) as u8);
            assert!(COL[i] < 9);
        }
        let mut c = [0u8; 9];
        for &v in COL.iter() {
            c[v as usize] += 1;
        }
        assert_eq!(c, [9u8; 9]);
    }

    #[test]
    fn l1_box_table_is_correct() {
        for i in 0..81usize {
            assert_eq!(BOX[i], ((i / 27) * 3 + (i % 9 / 3)) as u8);
            assert!(BOX[i] < 9);
        }
        let mut c = [0u8; 9];
        for &b in BOX.iter() {
            c[b as usize] += 1;
        }
        assert_eq!(c, [9u8; 9]);
    }

    #[test]
    fn l1_row_col_box_partition_is_consistent() {
        for i in 0..81usize {
            let r = ROW[i] as usize;
            let c = COL[i] as usize;
            let b = BOX[i] as usize;
            assert_eq!(r * 9 + c, i);
            assert_eq!(b, (r / 3) * 3 + (c / 3));
        }
    }

    // ---- L2: parsing -------------------------------------------------------

    #[test]
    fn l2_rejects_row_duplicate() {
        // Cells 0 and 1 share row 0; the second '1' is the one reported.
        let s = board_with(&[(0, 1), (1, 1)]);
        assert_eq!(
            s.parse::<Sudoku>().unwrap_err(),
            ParseError::DuplicateDigit { position: 1, digit: 1 }
        );
    }

    #[test]
    fn l2_rejects_col_duplicate() {
        // Cells 4 (r0c4) and 13 (r1c4) share column 4.
        let s = board_with(&[(4, 5), (13, 5)]);
        assert_eq!(COL[4], COL[13]);
        assert_eq!(
            s.parse::<Sudoku>().unwrap_err(),
            ParseError::DuplicateDigit { position: 13, digit: 5 }
        );
    }

    #[test]
    fn l2_rejects_box_duplicate() {
        // Cells 30 (r3c3) and 50 (r5c5) share only box 4, so this exercises
        // the box check in isolation.
        assert_eq!(BOX[30], BOX[50]);
        assert_ne!(ROW[30], ROW[50]);
        assert_ne!(COL[30], COL[50]);
        let s = board_with(&[(30, 3), (50, 3)]);
        assert_eq!(
            s.parse::<Sudoku>().unwrap_err(),
            ParseError::DuplicateDigit { position: 50, digit: 3 }
        );
    }

    #[test]
    fn l2_parse_error_variants() {
        assert_eq!("12345".parse::<Sudoku>().unwrap_err(), ParseError::InvalidLength(5));
        let non_digit = format!("x{}", "0".repeat(80));
        assert_eq!(
            non_digit.parse::<Sudoku>().unwrap_err(),
            ParseError::InvalidCharacter { position: 0, ch: 'x' }
        );
        let dup_row = format!("11{}", "0".repeat(79));
        assert!(matches!(
            dup_row.parse::<Sudoku>().unwrap_err(),
            ParseError::DuplicateDigit { digit: 1, .. }
        ));
    }

    #[test]
    fn l2_rejects_wrong_length() {
        assert_eq!(
            "0".repeat(80).parse::<Sudoku>().unwrap_err(),
            ParseError::InvalidLength(80)
        );
        assert_eq!(
            "0".repeat(82).parse::<Sudoku>().unwrap_err(),
            ParseError::InvalidLength(82)
        );
    }

    #[test]
    fn l2_accepts_all_zeros() {
        let s: Sudoku = "0".repeat(81).parse().expect("all-zero board is valid");
        assert!(s.cells.iter().all(|&v| v == 0));
        assert!(s.bitboard_matches_cells());
    }

    // ---- L3: hard puzzles --------------------------------------------------

    #[test]
    fn l3_al_escargot_solution_is_valid() {
        assert_solves(AL_ESCARGOT);
    }

    #[test]
    fn l3_hardest_2012_solution_is_valid() {
        assert_solves(HARDEST_2012);
    }

    #[test]
    fn l3_platinum_blonde_solution_is_valid() {
        assert_solves(PLATINUM_BLONDE);
    }

    // ---- L4: solver state and trait impls ----------------------------------

    #[test]
    fn l4_bitboard_consistent_after_parse() {
        let s: Sudoku = HARDEST_2012.parse().unwrap();
        assert!(s.bitboard_matches_cells());
    }

    #[test]
    fn l4_bitboard_consistent_after_solve() {
        let mut s: Sudoku = AL_ESCARGOT.parse().unwrap();
        assert!(s.solve());
        assert!(s.bitboard_matches_cells());
    }

    #[test]
    fn l4_already_solved_board_returns_true() {
        let solved =
            "123456789456789123789123456231564897564897231897231564312645978645978312978312645";
        let mut s: Sudoku = solved.parse().unwrap();
        let before = s.cells;
        assert!(s.solve());
        assert_eq!(s.cells, before);
        assert!(s.bitboard_matches_cells());
    }

    #[test]
    fn l4_unsolvable_board_returns_false_and_preserves_clues() {
        // Row 0 holds 1..=8, so cell 0 needs a 9, but cell 9 (same column and
        // box) already holds 9: cell 0 has no candidates.
        let p = board_with(&[
            (1, 1),
            (2, 2),
            (3, 3),
            (4, 4),
            (5, 5),
            (6, 6),
            (7, 7),
            (8, 8),
            (9, 9),
        ]);
        let mut s: Sudoku = p.parse().unwrap();
        let before = s;
        assert!(!s.solve());
        // A failed solve must leave the whole board, not just the clues, as it was.
        assert_eq!(s, before);
        assert_clues_preserved(&p, &s);
        assert!(s.bitboard_matches_cells());
    }

    #[test]
    fn l4_solve_is_deterministic() {
        let mut a: Sudoku = PLATINUM_BLONDE.parse().unwrap();
        let mut b: Sudoku = PLATINUM_BLONDE.parse().unwrap();
        assert!(a.solve());
        assert!(b.solve());
        assert_eq!(a.cells, b.cells);
    }

    #[test]
    fn l4_display_round_trips() {
        let s: Sudoku = HARDEST_2012.parse().unwrap();
        assert_eq!(format!("{}", s), HARDEST_2012);
        assert_eq!(s.to_digit_string(), HARDEST_2012);
        assert_eq!(format!("{:?}", s), format!("Sudoku(\"{}\")", HARDEST_2012));
    }

    #[test]
    fn l4_default_is_empty_board() {
        let s = Sudoku::default();
        assert!(s.cells.iter().all(|&v| v == 0));
        assert!(s.bitboard_matches_cells());
        assert!(!s.is_solved());
        assert_eq!(s, "0".repeat(81).parse::<Sudoku>().unwrap());
    }

    #[test]
    fn l4_accessors_and_from_string() {
        let s = Sudoku::from_string(HARDEST_2012).expect("valid puzzle");
        assert_eq!(s.get(0, 0), 8);
        assert_eq!(s.get(1, 2), 3);
        assert_eq!(s.cells(), &s.cells);
        assert!(Sudoku::from_string("12345").is_none());
    }

    #[test]
    #[should_panic(expected = "row and column must be 0..9")]
    fn l4_get_out_of_range_panics() {
        let _ = Sudoku::default().get(9, 0);
    }

    // ---- L5: corpus --------------------------------------------------------

    #[test]
    fn l5_batch_hard_puzzles() {
        let corpus = [
            ("Al Escargot", AL_ESCARGOT),
            ("Hardest 2012", HARDEST_2012),
            ("Platinum Blonde", PLATINUM_BLONDE),
            ("Norvig hard", NORVIG_HARD),
            ("Classic easy", CLASSIC_EASY),
        ];
        for (name, puzzle) in &corpus {
            let mut s: Sudoku = puzzle
                .parse()
                .unwrap_or_else(|e| panic!("{}: parse error: {}", name, e));
            assert!(s.solve(), "{}: no solution", name);
            assert!(is_valid_solution(&s.cells), "{}: invalid solution", name);
            assert_clues_preserved(puzzle, &s);
            assert!(s.is_solved(), "{}: cells not all filled", name);
            assert!(s.bitboard_matches_cells(), "{}: bitboard mismatch", name);
        }
    }
}