use std::collections::HashMap;
/// Solver for a "complete the line" puzzle on a square Gobang (five-in-a-row) board.
///
/// Each cell holds an `i32`: `0` marks an empty cell, any other value is a piece.
#[derive(Debug, Clone)]
pub struct GobangSolver {
    /// Cells indexed as `board[row][col]`.
    board: Vec<Vec<i32>>,
    /// Side length of the board, taken from the number of rows.
    n: usize,
}
impl GobangSolver {
    /// Creates a solver for `board`.
    ///
    /// The side length is taken from `board.len()`. The board is assumed to be square
    /// (every row at least `board.len()` cells long); this is not validated here.
    pub fn new(board: Vec<Vec<i32>>) -> Self {
        let n = board.len();
        Self { board, n }
    }

    /// Finds a single move that completes a full-length line.
    ///
    /// Lines are examined in the order produced by `iterate_lines` (rows, columns, then
    /// both diagonal directions). Only lines spanning the whole board (length `n`) count.
    /// A line qualifies when it contains exactly one empty cell (`0`) and its remaining
    /// `n - 1` cells all hold the same non-zero value `v`.
    ///
    /// The method uses the first qualifying line that has another `v` somewhere off it.
    /// It returns `Some([[remove_row, remove_col], [fill_row, fill_col]])`: moving the
    /// piece at `remove` into the empty cell `fill` completes the line. The `remove` cell
    /// is the first such piece in row-major order.
    ///
    /// Returns `None` when no such move exists, including for boards smaller than 2×2.
    ///
    /// # Panics
    ///
    /// Panics if a row is shorter than the number of rows.
    pub fn find_four_in_line(&self) -> Option<[[i32; 2]; 2]> {
        // Boards below 2x2 have no line holding both a piece and a gap; bailing out
        // here also keeps `self.n - 1` from underflowing.
        if self.n < 2 {
            return None;
        }
        let needed = self.n - 1;
        for line in self.iterate_lines().filter(|line| line.len() == self.n) {
            let elements: Vec<i32> = line.iter().map(|&(r, c)| self.board[r][c]).collect();
            let freq = Self::count_freq(&elements);
            // Require exactly one gap. Testing "zeros == n - 1" instead would reject
            // every fillable line on a 2x2 board (e.g. `[v, 0]`).
            if freq.get(&0) != Some(&1) {
                continue;
            }
            // With a single gap, at most one non-zero value can occupy the other n - 1
            // cells, so the HashMap iteration order cannot affect the result.
            let target = freq
                .iter()
                .find(|&(&value, &count)| value != 0 && count == needed)
                .map(|(&value, _)| value);
            let gap = elements.iter().position(|&x| x == 0).map(|idx| line[idx]);
            if let (Some(target), Some(fill)) = (target, gap) {
                if let Some(remove) = self.find_remove_candidate(target, &line) {
                    // Coordinates are < n, so these casts could only wrap for boards
                    // with more than i32::MAX rows.
                    return Some([
                        [remove.0 as i32, remove.1 as i32],
                        [fill.0 as i32, fill.1 as i32],
                    ]);
                }
            }
        }
        None
    }

    /// Yields every row, column and diagonal as a list of `(row, col)` coordinates.
    ///
    /// Order of the lines:
    /// - rows top to bottom;
    /// - columns left to right;
    /// - down-right diagonals starting on column 0 (from row 0 downward), then those
    ///   starting on row 0 (from column 1 rightward);
    /// - up-right diagonals starting on column 0 (from row 0 downward), then those
    ///   starting on the bottom row (from column 1 rightward).
    ///
    /// Diagonals shorter than `n` are included.
    fn iterate_lines(&self) -> impl Iterator<Item = Vec<(usize, usize)>> + '_ {
        let n = self.n;
        let rows = (0..n).map(move |row| (0..n).map(|col| (row, col)).collect::<Vec<_>>());
        let cols = (0..n).map(move |col| (0..n).map(|row| (row, col)).collect::<Vec<_>>());
        // Down-right diagonals anchored on the left edge.
        let main_diag_1 = (0..n).map(move |start_row| {
            (0..(n - start_row))
                .map(|i| (start_row + i, i))
                .collect::<Vec<_>>()
        });
        // Down-right diagonals anchored on the top edge (column 0 already covered).
        let main_diag_2 = (1..n).map(move |start_col| {
            (0..(n - start_col))
                .map(|i| (i, start_col + i))
                .collect::<Vec<_>>()
        });
        // Up-right diagonals anchored on the left edge.
        let anti_diag_1 = (0..n).map(move |start_row| {
            (0..=start_row)
                .map(|i| (start_row - i, i))
                .collect::<Vec<_>>()
        });
        // Up-right diagonals anchored on the bottom edge (column 0 already covered).
        let anti_diag_2 = (1..n).map(move |start_col| {
            (0..(n - start_col))
                .map(|i| (n - 1 - i, start_col + i))
                .collect::<Vec<_>>()
        });
        rows.chain(cols)
            .chain(main_diag_1)
            .chain(main_diag_2)
            .chain(anti_diag_1)
            .chain(anti_diag_2)
    }

    /// Counts how many times each value occurs in `elements`.
    fn count_freq(elements: &[i32]) -> HashMap<i32, usize> {
        let mut freq: HashMap<i32, usize> = HashMap::new();
        for &num in elements {
            *freq.entry(num).or_default() += 1;
        }
        freq
    }

    /// Returns the first cell, in row-major order, that holds `target` and is not
    /// listed in `exclude`. Returns `None` if there is no such cell.
    fn find_remove_candidate(
        &self,
        target: i32,
        exclude: &[(usize, usize)],
    ) -> Option<(usize, usize)> {
        let excluded: std::collections::HashSet<(usize, usize)> =
            exclude.iter().copied().collect();
        (0..self.n)
            .flat_map(|r| (0..self.n).map(move |c| (r, c)))
            .find(|&(r, c)| !excluded.contains(&(r, c)) && self.board[r][c] == target)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gobang_solver_basic() {
        // Row 0 holds four 1s with a gap at column 3. The first 1 off that row, in
        // row-major order, is at (1, 4).
        let board = vec![
            vec![1, 1, 1, 0, 1],
            vec![2, 2, 2, 2, 1],
            vec![3, 3, 3, 3, 3],
            vec![4, 4, 4, 4, 4],
            vec![5, 5, 5, 5, 5],
        ];
        let solver = GobangSolver::new(board);
        let [[remove_r, remove_c], [fill_r, fill_c]] = solver
            .find_four_in_line()
            .expect("row 0 can be completed");
        assert_eq!((fill_r, fill_c), (0, 3));
        assert_eq!((remove_r, remove_c), (1, 4));
    }

    #[test]
    fn test_column_line() {
        // Column 2 holds four 3s with a gap at the bottom; the only spare 3 is at (4, 3).
        let board = vec![
            vec![1, 2, 3, 4, 5],
            vec![2, 1, 3, 5, 4],
            vec![4, 5, 3, 1, 2],
            vec![5, 4, 3, 2, 1],
            vec![1, 2, 0, 3, 4],
        ];
        let solver = GobangSolver::new(board);
        assert_eq!(solver.find_four_in_line(), Some([[4, 3], [4, 2]]));
    }

    #[test]
    fn test_no_piece_to_move() {
        // Row 0 could be completed, but no other 1 exists to move into the gap.
        let board = vec![
            vec![1, 1, 1, 1, 0],
            vec![0; 5],
            vec![0; 5],
            vec![0; 5],
            vec![0; 5],
        ];
        let solver = GobangSolver::new(board);
        assert_eq!(solver.find_four_in_line(), None);
    }

    #[test]
    fn test_count_freq() {
        let elements = vec![1, 1, 1, 0, 2];
        let freq = GobangSolver::count_freq(&elements);
        assert_eq!(freq.get(&1), Some(&3));
        assert_eq!(freq.get(&0), Some(&1));
        assert_eq!(freq.get(&2), Some(&1));
    }

    #[test]
    fn test_empty_board() {
        let board = vec![vec![0; 5]; 5];
        let solver = GobangSolver::new(board);
        assert_eq!(solver.find_four_in_line(), None);
    }
}