// latin_sampler/square.rs

/// Latin square of order `n`, stored as a flat row-major buffer.
///
/// Every one of the `n` rows and `n` columns contains each symbol of
/// `{0..n-1}` exactly once.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct LatinSquare {
    /// Order of the square (its side length).
    n: usize,
    /// The `n * n` symbols in row-major order: cell `(r, c)` is at `r * n + c`.
    cells: Vec<u8>,
}
10
11impl LatinSquare {
12    /// Creates the cyclic Latin square of order `n`: `L[r][c] = (r + c) mod n`.
13    ///
14    /// # Panics
15    /// Panics if `n < 2` or `n > 255`.
16    pub(crate) fn new_cyclic(n: usize) -> Self {
17        assert!((2..=255).contains(&n), "n must be in range 2..=255");
18        let cells = (0..n)
19            .flat_map(|r| (0..n).map(move |c| ((r + c) % n) as u8))
20            .collect();
21        Self { n, cells }
22    }
23
24    /// Returns the order of the Latin square.
25    pub fn n(&self) -> usize {
26        self.n
27    }
28
29    /// Returns the value at position `(r, c)`.
30    ///
31    /// # Panics
32    /// Panics if `r >= n` or `c >= n`.
33    pub fn get(&self, r: usize, c: usize) -> u8 {
34        assert!(r < self.n && c < self.n, "index out of bounds");
35        self.cells[r * self.n + c]
36    }
37
38    /// Sets the value at position `(r, c)` without checking the Latin property.
39    pub(crate) fn set_unchecked(&mut self, r: usize, c: usize, v: u8) {
40        self.cells[r * self.n + c] = v;
41    }
42
43    /// Returns the cells as a flat slice in row-major order.
44    ///
45    /// The cell at position (r, c) is at index `r * n + c`.
46    pub fn cells(&self) -> &[u8] {
47        &self.cells
48    }
49}
50
#[cfg(feature = "serde")]
impl serde::Serialize for LatinSquare {
    /// Serializes the square as a sequence of `n` rows, each a sequence of `n`
    /// symbols (e.g. JSON `[[0,1,2],[1,2,0],[2,0,1]]` for the cyclic square of
    /// order 3).
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeSeq;
        let mut seq = serializer.serialize_seq(Some(self.n))?;
        // `cells` is row-major with exactly `n * n` entries, so each exact chunk
        // of length `n` is one row. This relies on `n >= 2`, which `new_cyclic`
        // asserts; `chunks_exact(0)` would panic.
        // A borrowed `&[u8]` serializes exactly like a `Vec<u8>` (both go through
        // `collect_seq`), so no per-row allocation is needed.
        for row in self.cells.chunks_exact(self.n) {
            seq.serialize_element(row)?;
        }
        seq.end()
    }
}
66
67impl LatinSquare {
68    /// Returns true if this is a valid Latin square.
69    ///
70    /// This is a test-only helper for validation. The Latin property is an
71    /// invariant enforced by construction and moves.
72    #[cfg(test)]
73    pub(crate) fn is_latin(&self) -> bool {
74        let n = self.n;
75        let mut seen = vec![false; n];
76        // Check rows
77        for r in 0..n {
78            seen.fill(false);
79            for c in 0..n {
80                let v = self.get(r, c) as usize;
81                if v >= n || seen[v] {
82                    return false;
83                }
84                seen[v] = true;
85            }
86        }
87        // Check columns
88        for c in 0..n {
89            seen.fill(false);
90            for r in 0..n {
91                let v = self.get(r, c) as usize;
92                if v >= n || seen[v] {
93                    return false;
94                }
95                seen[v] = true;
96            }
97        }
98        true
99    }
100}
101
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cyclic_is_latin() {
        for n in 2..=10 {
            let sq = LatinSquare::new_cyclic(n);
            assert!(
                sq.is_latin(),
                "cyclic square of order {} should be Latin",
                n
            );
        }
    }

    /// Pins the exact cyclic formula and the row-major layout exposed by `cells()`.
    #[test]
    fn cyclic_layout_matches_formula() {
        for n in 2..=10 {
            let sq = LatinSquare::new_cyclic(n);
            assert_eq!(sq.n(), n);
            assert_eq!(sq.cells().len(), n * n);
            for r in 0..n {
                for c in 0..n {
                    let expected = ((r + c) % n) as u8;
                    assert_eq!(sq.get(r, c), expected);
                    assert_eq!(sq.cells()[r * n + c], expected);
                }
            }
        }
    }

    #[test]
    fn new_cyclic_accepts_range_bounds() {
        assert!(LatinSquare::new_cyclic(2).is_latin());
        assert!(LatinSquare::new_cyclic(255).is_latin());
    }

    #[test]
    #[should_panic(expected = "n must be in range 2..=255")]
    fn new_cyclic_rejects_order_one() {
        LatinSquare::new_cyclic(1);
    }

    #[test]
    #[should_panic(expected = "n must be in range 2..=255")]
    fn new_cyclic_rejects_order_256() {
        LatinSquare::new_cyclic(256);
    }

    /// `c == n` must panic rather than silently read the first cell of the next row.
    #[test]
    #[should_panic(expected = "index out of bounds")]
    fn get_rejects_out_of_range_column() {
        LatinSquare::new_cyclic(3).get(0, 3);
    }

    #[test]
    fn is_latin_detects_repeated_symbol() {
        let mut sq = LatinSquare::new_cyclic(3);
        // Row 0 becomes [1, 1, 2] and column 0 becomes [1, 1, 2].
        sq.set_unchecked(0, 0, 1);
        assert!(!sq.is_latin());
    }

    #[test]
    fn is_latin_rejects_out_of_range_symbol() {
        let mut sq = LatinSquare::new_cyclic(3);
        sq.set_unchecked(1, 1, 3);
        assert!(!sq.is_latin());
    }
}
118
#[cfg(all(test, feature = "serde"))]
mod serde_tests {
    use super::*;

    #[test]
    fn serialize_cyclic_3x3() {
        let sq = LatinSquare::new_cyclic(3);
        let json = serde_json::to_string(&sq).unwrap();
        assert_eq!(json, "[[0,1,2],[1,2,0],[2,0,1]]");
    }

    /// Round-trips through JSON and checks every serialized row against `get`,
    /// not just the dimensions.
    #[test]
    fn serialize_cyclic_various_sizes() {
        for n in 2..=10 {
            let sq = LatinSquare::new_cyclic(n);
            let json = serde_json::to_string(&sq).unwrap();
            let parsed: Vec<Vec<u8>> = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed.len(), n);
            for (r, row) in parsed.iter().enumerate() {
                let expected: Vec<u8> = (0..n).map(|c| sq.get(r, c)).collect();
                assert_eq!(row, &expected, "row {} of order {} square", r, n);
            }
        }
    }
}