// board_game/games/sttt.rs

1use std::fmt;
2use std::fmt::{Debug, Display, Formatter};
3use std::ops::ControlFlow;
4
5use internal_iterator::InternalIterator;
6use itertools::Itertools;
7use rand::Rng;
8
9use crate::board::{
10    AllMovesIterator, Alternating, AvailableMovesIterator, Board, BoardDone, BoardMoves, BoardSymmetry, Outcome,
11    PlayError, Player,
12};
13use crate::symmetry::D4Symmetry;
14use crate::util::bits::{get_nth_set_bit, BitIter};
15use crate::util::iter::ClonableInternal;
16
/// A single tile of the 9x9 board, stored as `9 * om + os`, where `om` is the index of
/// the 3x3 macro grid and `os` the index of the tile inside that grid.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Coord(u8);
19
20#[derive(Clone, Eq, PartialEq, Hash)]
21pub struct STTTBoard {
22    grids: [u32; 9],
23    main_grid: u32,
24
25    last_move: Option<Coord>,
26    next_player: Player,
27    outcome: Option<Outcome>,
28
29    macro_mask: u32,
30    macro_open: u32,
31}
32
33impl Default for STTTBoard {
34    fn default() -> STTTBoard {
35        STTTBoard {
36            grids: [0; 9],
37            main_grid: 0,
38            last_move: None,
39            next_player: Player::A,
40            outcome: None,
41            macro_mask: STTTBoard::FULL_MASK,
42            macro_open: STTTBoard::FULL_MASK,
43        }
44    }
45}
46
47impl STTTBoard {
48    const FULL_MASK: u32 = 0b111_111_111;
49
50    pub fn tile(&self, coord: Coord) -> Option<Player> {
51        get_player(self.grids[coord.om() as usize], coord.os())
52    }
53
54    pub fn macr(&self, om: u8) -> Option<Player> {
55        debug_assert!(om < 9);
56        get_player(self.main_grid, om)
57    }
58
59    pub fn is_macro_open(&self, om: u8) -> bool {
60        debug_assert!(om < 9);
61        has_bit(self.macro_open, om)
62    }
63
64    /// Return the number of non-empty tiles.
65    pub fn count_tiles(&self) -> u32 {
66        self.grids.iter().map(|tile| tile.count_ones()).sum()
67    }
68
69    fn set_tile_and_update(&mut self, player: Player, coord: Coord) {
70        let om = coord.om();
71        let os = coord.os();
72        let p = 9 * player.index();
73
74        //set tile and macro, check win
75        let new_grid = self.grids[om as usize] | (1 << (os + p));
76        self.grids[om as usize] = new_grid;
77
78        let grid_win = is_win_grid((new_grid >> p) & STTTBoard::FULL_MASK);
79        if grid_win {
80            let new_main_grid = self.main_grid | (1 << (om + p));
81            self.main_grid = new_main_grid;
82
83            if is_win_grid((new_main_grid >> p) & STTTBoard::FULL_MASK) {
84                self.outcome = Some(Outcome::WonBy(player));
85            }
86        }
87
88        //update macro masks, remove bit from open and recalculate mask
89        if grid_win || new_grid.count_ones() == 9 {
90            self.macro_open &= !(1 << om);
91            if self.macro_open == 0 && self.outcome.is_none() {
92                self.outcome = Some(Outcome::Draw);
93            }
94        }
95        self.macro_mask = self.calc_macro_mask(os);
96    }
97
98    fn calc_macro_mask(&self, os: u8) -> u32 {
99        if has_bit(self.macro_open, os) {
100            1u32 << os
101        } else {
102            self.macro_open
103        }
104    }
105}
106
107impl Board for STTTBoard {
108    type Move = Coord;
109
110    fn next_player(&self) -> Player {
111        self.next_player
112    }
113
114    fn is_available_move(&self, mv: Self::Move) -> Result<bool, BoardDone> {
115        self.check_done()?;
116
117        let can_play_in_macro = has_bit(self.macro_mask, mv.om());
118        let micro_occupied = has_bit(compact_grid(self.grids[mv.om() as usize]), mv.os());
119
120        Ok(can_play_in_macro && !micro_occupied)
121    }
122
123    fn random_available_move(&self, rng: &mut impl Rng) -> Result<Self::Move, BoardDone> {
124        // TODO we can also implement size_hint and skip for the available move iterator,
125        //   then we don't need this complicated body any more
126
127        self.check_done()?;
128
129        let mut count = 0;
130        for om in BitIter::new(self.macro_mask) {
131            count += 9 - self.grids[om as usize].count_ones();
132        }
133
134        let mut index = rng.gen_range(0..count);
135
136        for om in BitIter::new(self.macro_mask) {
137            let grid = self.grids[om as usize];
138            let grid_count = 9 - grid.count_ones();
139
140            if index < grid_count {
141                let os = get_nth_set_bit(!compact_grid(grid), index);
142                return Ok(Coord::from_oo(om, os));
143            }
144
145            index -= grid_count;
146        }
147
148        unreachable!()
149    }
150
151    fn play(&mut self, mv: Self::Move) -> Result<(), PlayError> {
152        self.check_can_play(mv)?;
153
154        //do actual move
155        self.set_tile_and_update(self.next_player, mv);
156
157        //update for next player
158        self.last_move = Some(mv);
159        self.next_player = self.next_player.other();
160
161        Ok(())
162    }
163
164    fn outcome(&self) -> Option<Outcome> {
165        self.outcome
166    }
167
168    fn can_lose_after_move() -> bool {
169        false
170    }
171}
172
173impl Alternating for STTTBoard {}
174
175impl BoardSymmetry<STTTBoard> for STTTBoard {
176    type Symmetry = D4Symmetry;
177    type CanonicalKey = (u32, Option<Coord>, u32, u32);
178
179    fn map(&self, sym: D4Symmetry) -> STTTBoard {
180        let mut grids = [0; 9];
181        for oo in 0..9 {
182            grids[map_oo(sym, oo) as usize] = map_grid(sym, self.grids[oo as usize])
183        }
184
185        STTTBoard {
186            grids,
187            main_grid: map_grid(sym, self.main_grid),
188            last_move: self.last_move.map(|c| self.map_move(sym, c)),
189            next_player: self.next_player,
190            outcome: self.outcome,
191            macro_mask: map_grid(sym, self.macro_mask),
192            macro_open: map_grid(sym, self.macro_open),
193        }
194    }
195
196    fn map_move(&self, sym: D4Symmetry, mv: Coord) -> Coord {
197        Coord::from_oo(map_oo(sym, mv.om()), map_oo(sym, mv.os()))
198    }
199
200    fn canonical_key(&self) -> Self::CanonicalKey {
201        (self.main_grid, self.last_move, self.macro_mask, self.macro_open)
202    }
203}
204
205impl<'a> BoardMoves<'a, STTTBoard> for STTTBoard {
206    type AllMovesIterator = ClonableInternal<CoordIter>;
207    type AvailableMovesIterator = AvailableMovesIterator<'a, STTTBoard>;
208
209    fn all_possible_moves() -> Self::AllMovesIterator {
210        ClonableInternal::new(Coord::all())
211    }
212
213    fn available_moves(&'a self) -> Result<Self::AvailableMovesIterator, BoardDone> {
214        AvailableMovesIterator::new(self)
215    }
216}
217
218impl InternalIterator for AllMovesIterator<STTTBoard> {
219    type Item = Coord;
220
221    fn try_for_each<R, F>(self, f: F) -> ControlFlow<R>
222    where
223        F: FnMut(Self::Item) -> ControlFlow<R>,
224    {
225        Coord::all().try_for_each(f)
226    }
227}
228
229impl<'a> InternalIterator for AvailableMovesIterator<'a, STTTBoard> {
230    type Item = Coord;
231
232    fn try_for_each<R, F: FnMut(Self::Item) -> ControlFlow<R>>(self, mut f: F) -> ControlFlow<R> {
233        let board = self.board();
234        for om in BitIter::new(board.macro_mask) {
235            let free_grid = (!compact_grid(board.grids[om as usize])) & STTTBoard::FULL_MASK;
236            for os in BitIter::new(free_grid) {
237                f(Coord::from_oo(om, os))?;
238            }
239        }
240
241        ControlFlow::Continue(())
242    }
243}
244
245pub type CoordIter = std::iter::Map<std::ops::Range<u8>, fn(u8) -> Coord>;
246
247impl Coord {
248    pub fn all() -> CoordIter {
249        (0..81).map(Self::from_o)
250    }
251
252    pub fn all_yx() -> CoordIter {
253        (0..81).map(|i| Self::from_xy(i % 9, i / 9))
254    }
255
256    pub fn from_oo(om: u8, os: u8) -> Coord {
257        debug_assert!(om < 9);
258        debug_assert!(os < 9);
259        Coord(9 * om + os)
260    }
261
262    pub fn from_o(o: u8) -> Coord {
263        debug_assert!(o < 81);
264        Coord(o)
265    }
266
267    pub fn from_xy(x: u8, y: u8) -> Coord {
268        debug_assert!(x < 9 && y < 9);
269        Coord(((x / 3) + (y / 3) * 3) * 9 + ((x % 3) + (y % 3) * 3))
270    }
271
272    pub fn om(self) -> u8 {
273        self.0 / 9
274    }
275
276    pub fn os(self) -> u8 {
277        self.0 % 9
278    }
279
280    pub fn o(self) -> u8 {
281        9 * self.om() + self.os()
282    }
283
284    pub fn yx(self) -> u8 {
285        9 * self.y() + self.x()
286    }
287
288    pub fn x(self) -> u8 {
289        (self.om() % 3) * 3 + (self.os() % 3)
290    }
291
292    pub fn y(self) -> u8 {
293        (self.om() / 3) * 3 + (self.os() / 3)
294    }
295}
296
297impl Debug for Coord {
298    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
299        write!(f, "Coord({}, {})", self.om(), self.os())
300    }
301}
302
303impl Display for Coord {
304    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
305        write!(f, "({}, {})", self.om(), self.os())
306    }
307}
308
309fn map_oo(sym: D4Symmetry, oo: u8) -> u8 {
310    let (x, y) = sym.map_xy(oo % 3, oo / 3, 3);
311    x + y * 3
312}
313
314fn map_grid(sym: D4Symmetry, grid: u32) -> u32 {
315    // this could be implemented faster but it's not on a hot path
316    let mut result = 0;
317    for oo_input in 0..9 {
318        let oo_result = map_oo(sym, oo_input);
319        let get = (grid >> oo_input) & 0b1_000_000_001;
320        result |= get << oo_result;
321    }
322    result
323}
324
325fn is_win_grid(grid: u32) -> bool {
326    debug_assert!(has_mask(STTTBoard::FULL_MASK, grid));
327
328    const WIN_GRIDS: [u32; 16] = [
329        2155905152, 4286611584, 4210076288, 4293962368, 3435954304, 4291592320, 4277971584, 4294748800, 2863300736,
330        4294635760, 4210731648, 4294638320, 4008607872, 4294897904, 4294967295, 4294967295,
331    ];
332    has_bit(WIN_GRIDS[(grid / 32) as usize], (grid % 32) as u8)
333}
334
/// Whether bit `i` of `x` is set.
fn has_bit(x: u32, i: u8) -> bool {
    (x & (1u32 << i)) != 0
}
338
/// Whether every bit set in `mask` is also set in `x`.
fn has_mask(x: u32, mask: u32) -> bool {
    // no bit of `mask` may fall outside `x`
    (!x & mask) == 0
}
342
343fn compact_grid(grid: u32) -> u32 {
344    (grid | grid >> 9) & STTTBoard::FULL_MASK
345}
346
347fn get_player(grid: u32, index: u8) -> Option<Player> {
348    if has_bit(grid, index) {
349        Some(Player::A)
350    } else if has_bit(grid, index + 9) {
351        Some(Player::B)
352    } else {
353        None
354    }
355}
356
357fn symbol_from_tile(board: &STTTBoard, coord: Coord) -> char {
358    let is_last = Some(coord) == board.last_move;
359    let is_available = board.is_available_move(coord).unwrap_or(false);
360    let player = board.tile(coord);
361    symbol_from_tuple(is_available, is_last, player)
362}
363
364fn symbol_from_tuple(is_available: bool, is_last: bool, player: Option<Player>) -> char {
365    let tuple = (is_available, is_last, player);
366    match tuple {
367        (false, false, Some(Player::A)) => 'x',
368        (false, true, Some(Player::A)) => 'X',
369        (false, false, Some(Player::B)) => 'o',
370        (false, true, Some(Player::B)) => 'O',
371        (true, false, None) => '.',
372        (false, false, None) => ' ',
373        _ => unreachable!("Invalid tile state {:?}", tuple),
374    }
375}
376
377fn symbol_to_tuple(c: char) -> (bool, bool, Option<Player>) {
378    match c {
379        'x' => (false, false, Some(Player::A)),
380        'X' => (false, true, Some(Player::A)),
381        'o' => (false, false, Some(Player::B)),
382        'O' => (false, true, Some(Player::B)),
383        ' ' => (false, false, None),
384        '.' => (true, false, None),
385        _ => panic!("unexpected character '{}'", c),
386    }
387}
388
389impl fmt::Debug for STTTBoard {
390    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
391        write!(f, "STTTBoard({:?})", board_to_compact_string(self))
392    }
393}
394
395impl fmt::Display for STTTBoard {
396    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
397        for y in 0..9 {
398            if y == 3 || y == 6 {
399                writeln!(f, "---+---+---     +---+")?;
400            }
401
402            for x in 0..9 {
403                if x == 3 || x == 6 {
404                    write!(f, "|")?;
405                }
406                write!(f, "{}", symbol_from_tile(self, Coord::from_xy(x, y)))?;
407            }
408
409            if (3..6).contains(&y) {
410                write!(f, "     |")?;
411                let ym = y - 3;
412                for xm in 0..3 {
413                    let om = xm + 3 * ym;
414                    write!(f, "{}", symbol_from_tuple(self.is_macro_open(om), false, self.macr(om)))?;
415                }
416                write!(f, "|")?;
417            }
418
419            writeln!(f)?;
420        }
421
422        Ok(())
423    }
424}
425
426pub fn board_to_compact_string(board: &STTTBoard) -> String {
427    Coord::all().map(|coord| symbol_from_tile(board, coord)).join("")
428}
429
430pub fn board_from_compact_string(s: &str) -> STTTBoard {
431    assert!(s.chars().count() == 81, "compact string should have length 81");
432
433    let mut board = STTTBoard::default();
434    let mut last_move = None;
435
436    for (o, c) in s.chars().enumerate() {
437        let coord = Coord::from_o(o as u8);
438        let (_, last, player) = symbol_to_tuple(c);
439
440        if last {
441            assert!(last_move.is_none(), "Compact string cannot contain multiple last moves");
442            let player = player.expect("Last move must have been played by a player");
443            last_move = Some((player, coord));
444        }
445
446        if let Some(player) = player {
447            board.set_tile_and_update(player, coord);
448        }
449    }
450
451    if let Some((last_player, last_coord)) = last_move {
452        board.set_tile_and_update(last_player, last_coord);
453        board.last_move = Some(last_coord);
454        board.next_player = last_player.other()
455    }
456
457    board
458}