#![allow(
clippy::many_single_char_names, // r/c/v/i/j/k are conventional for Latin-square indices
)]
use crate::N;
use rand::{Rng, RngExt};
/// Picks, uniformly at random, one index `x` in `0..n` for which `line(x) == 1`.
///
/// `line` reads a single axis of the incidence cube with the other two coordinates held
/// fixed. At most two hits are supported: one on a line of a proper square, two on a line
/// through the improper (`-1`) cell. Zero hits panics on the empty range passed to
/// `random_range`; a third hit panics on the out-of-bounds store into the hit buffer.
fn pick_one_from_line(rng: &mut impl Rng, n: usize, line: impl Fn(usize) -> i8) -> usize {
    // Gather the positions holding a 1 into a fixed two-slot buffer, tracking how many
    // were found.
    let (hits, len) = (0..n)
        .filter(|&pos| line(pos) == 1)
        .fold(([0usize; 2], 0usize), |(mut hits, len), pos| {
            hits[len] = pos;
            (hits, len + 1)
        });
    let choice = rng.random_range(0..len);
    hits[choice]
}
pub fn generate_latin_square(n: usize, rng: &mut impl Rng) -> Vec<Vec<N>> {
if n == 1 {
return vec![vec![1]];
}
let mut m: Vec<Vec<Vec<i8>>> = vec![vec![vec![0i8; n]; n]; n];
for r in 0..n {
for c in 0..n {
m[r][c][(r + c) % n] = 1;
}
}
let mut improper: Option<(usize, usize, usize)> = None;
let target_moves = 6 * n * n * n;
let mut moves = 0usize;
while moves < target_moves || improper.is_some() {
let (i, j, k) = improper.unwrap_or_else(|| {
loop {
let r = rng.random_range(0..n);
let c = rng.random_range(0..n);
let v = rng.random_range(0..n);
if m[r][c][v] == 0 {
break (r, c, v);
}
}
});
let ip = pick_one_from_line(rng, n, |x| m[x][j][k]);
let jp = pick_one_from_line(rng, n, |x| m[i][x][k]);
let kp = pick_one_from_line(rng, n, |x| m[i][j][x]);
m[i][j][k] += 1;
m[ip][j][k] -= 1;
m[i][jp][k] -= 1;
m[i][j][kp] -= 1;
m[ip][jp][k] += 1;
m[ip][j][kp] += 1;
m[i][jp][kp] += 1;
m[ip][jp][kp] -= 1;
improper = (m[ip][jp][kp] == -1).then_some((ip, jp, kp));
moves += 1;
}
(0..n)
.map(|r| {
(0..n)
.map(|c| {
let v = (0..n).position(|v| m[r][c][v] == 1).unwrap_or(0);
#[allow(clippy::cast_possible_truncation)]
{
(v + 1) as N
}
})
.collect()
})
.collect()
}
#[cfg(test)]
mod tests {
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;
    use std::collections::{HashMap, HashSet};

    use super::*;

    /// Returns `true` if `ls` is a Latin square over the symbols `1..=ls.len()`.
    ///
    /// That means the grid is square and every row and every column contains each symbol
    /// exactly once.
    fn validate_latin_square(ls: &[Vec<N>]) -> bool {
        let n = ls.len();
        // Without this check, a row longer than `n` with repeated symbols would still
        // collapse to the expected set, and a shorter row would panic in the column scan.
        if ls.iter().any(|row| row.len() != n) {
            return false;
        }
        let expected: HashSet<N> = (1..=n)
            .map(|s| N::try_from(s).expect("test grid order fits in N"))
            .collect();
        // Each row/column has exactly `n` entries, so equal to the `n`-symbol set means
        // each symbol appears exactly once.
        let rows_ok = ls
            .iter()
            .all(|row| row.iter().copied().collect::<HashSet<N>>() == expected);
        let cols_ok =
            (0..n).all(|c| ls.iter().map(|row| row[c]).collect::<HashSet<N>>() == expected);
        rows_ok && cols_ok
    }

    #[test]
    fn generate_4x4_returns_valid_square() {
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let ls = generate_latin_square(4, &mut rng);
        assert!(validate_latin_square(&ls));
    }

    #[test]
    fn generate_1x1_returns_valid_square() {
        let mut rng = ChaCha8Rng::seed_from_u64(7);
        let ls = generate_latin_square(1, &mut rng);
        let expected: Vec<Vec<N>> = vec![vec![1]];
        assert_eq!(ls, expected);
        assert!(validate_latin_square(&ls));
    }

    #[test]
    fn generate_small_orders_return_valid_squares() {
        let mut rng = ChaCha8Rng::seed_from_u64(1);
        for n in 0..=7 {
            let ls = generate_latin_square(n, &mut rng);
            assert_eq!(ls.len(), n);
            assert!(validate_latin_square(&ls), "order {n} produced {ls:?}");
        }
    }

    #[test]
    fn validate_rejects_invalid() {
        let ls: Vec<Vec<N>> = vec![vec![1, 1, 3], vec![2, 3, 1], vec![3, 2, 2]];
        assert!(!validate_latin_square(&ls));
    }

    /// There are exactly 12 Latin squares of order 3 (only one of them is reduced);
    /// the chain should reach all of them with roughly equal frequency.
    #[test]
    fn generates_all_twelve_3x3_squares() {
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let mut counts: HashMap<Vec<Vec<N>>, usize> = HashMap::new();
        for _ in 0..1200 {
            let ls = generate_latin_square(3, &mut rng);
            assert!(validate_latin_square(&ls), "invalid grid {ls:?}");
            *counts.entry(ls).or_insert(0) += 1;
        }
        assert_eq!(counts.len(), 12);
        for (grid, &count) in &counts {
            assert!(count >= 10, "grid {grid:?} only appeared {count} times");
        }
    }

    #[test]
    fn validate_rejects_invalid_column() {
        let ls: Vec<Vec<N>> = vec![vec![1, 2, 3], vec![1, 3, 2], vec![1, 2, 3]];
        assert!(!validate_latin_square(&ls));
    }

    #[test]
    fn validate_rejects_ragged_rows() {
        // Every row's symbol set is {1, 2} and every column scan stays in bounds, but the
        // grid is not square.
        let ls: Vec<Vec<N>> = vec![vec![1, 2], vec![2, 1, 1]];
        assert!(!validate_latin_square(&ls));
    }
}